Gemma 2のファインチューニングとKerasの活用法
2024年6月27日、Kerasのプロダクトマネージャーであるマーチン・ゴーナー氏が発表したように、新しいGemma 2モデルがKeras上で利用可能になりました。このモデルは、9Bおよび27Bパラメータという2つのサイズで展開されており、標準およびインストラクションチューニング済みのバリエーションを含んでいます。
Gemma 2について知っておくべきこと
- Gemma 2は、優れた文脈理解能力を持ち、多くのLLMベンチマークでの優れた結果が報告されています。
- このモデルは、KaggleやHugging Faceからアクセスすることができます。
- Gemmaの新しいバージョンは、従来のモデルを進化させ、多くの機能強化が加えられています。
Gemma 2は、JAXというスケーラブルな数値計算フレームワークを利用しています。JAXはXLAという機械学習コンパイラを用いて、Googleで最も大きなモデルのトレーニングをサポートしています。また、KerasはMLエンジニア向けのモデリングフレームワークで、JAX、TensorFlow、PyTorchの上で動作します。Kerasは、モデルのパラレルスケーリングの力を、使いやすいKeras APIを通じて提供します。
大規模モデルの利用: モデル並列化
Gemma 2のような大型モデルは、そのサイズのために、加速器間でウェイトを分割することでのみフル精度でロードおよびファインチューニングが可能になります。JAXとXLAは、ウェイトのパーティショニング(SPMDモデルパラレル処理)を広くサポートしており、Kerasは.keras.distribution.ModelParallel APIを提供して、レイヤーごとにシャーディングを指定するのを簡単にしています。
- デバイスの一覧: keras.distribution.list_devices() を使用してデバイスを確認。
- デバイスメッシュの定義: 論理グリッドを設定し、シャーディングを指定します。
- モデルの並列分配の設定: keras.distribution.ModelParallelを使用してモデルの設定を行います。
- モデルのロード: gemma2_lmの設定を行い、モデルをインスタンス化します。
これにより、各レイヤーのウェイトをどの加速器に置くかが簡単に定義でき、効率的なトレーニングが可能になります。また、LoRAと呼ばれる技術を使用することで、モデルの一部のパラメータを凍結し、低ランクのアダプターに置き換えることも可能です。これにより、Gemma 9Bの訓練可能なパラメータの数を90億から1450万にまで削減することができます。
Hugging Faceとの統合の進化
最近の発表では、Hugging Faceの統合も強化され、ユーザーがKerasモデルのダウンロードやアップロードを簡単に行えるようになりました。これにより、Hugging Faceにアップロードされた多数のGemmaファインチューニングモデルを直接KerasNLPから利用できるようになります。この機能は、今後あらゆるHugging Face Transformersモデルに対応する予定です。
PaliGemmaの活用
PaliGemmaは、PaLI-3からインスパイアを受けた強力なオープンVLMです。SigLIPビジョンモデルとGemma言語モデルを基盤にしており、画像キャプショニングや視覚質問応答、画像内のテキスト理解、物体検出、物体セグメンテーションなど、幅広い視覚言語タスクに対して最高のファインチューニング性能を提供します。PaliGemmaのKeras実装も、GitHubやHugging Face、およびKaggleで利用可能です。
まとめ
新しいGemma 2モデルをKeras上で試験的に使ったり、構築したりすることができるようになりました。これにより、機械学習の世界において、より広範な利用が期待されています。この技術を駆使して、AIの未来を共に形作っていくことができるでしょう。