JAX と Flax を使用した Gemma のファインチューニング

ai.google.dev で表示 Google Colab で実行 Vertex AI で開く GitHub でソースを表示

概要

Gemma は、Google DeepMind Gemini の研究とテクノロジーをベースにした、軽量で最先端のオープン 大規模言語モデルのファミリーです。このチュートリアルでは、Google DeepMind の gemma ライブラリJAX(高パフォーマンス数値計算ライブラリ)、Flax(JAX ベースのニューラル ネットワーク ライブラリ)、Chex(信頼性の高い JAX および MT2 勾配変換データセット Nois2 および MT2 勾配変換用ユーティリティのライブラリ)、(Theこのノートブックでは Flax は直接使用されていませんが、Gemma の作成には Flax が使用されました。

gemma ライブラリは、JAX、Flax、Orbax(チェックポインティングなどのトレーニング ユーティリティ用の JAX ベースのライブラリ)、SentencePiece(トークナイザ/デトークナイザ ライブラリ)で記述されています。

設定

1. Gemma 用に Kaggle のアクセス権を設定する

このチュートリアルを完了するには、まず Gemma のセットアップに記載されている手順に沿って操作する必要があります。ここでは、以下の方法について説明しています。

  • kaggle.com で Gemma にアクセスできます。
  • Gemma モデルの実行に十分なリソースがある Colab ランタイムを選択します。
  • Kaggle ユーザー名と API キーを生成して構成します。

Gemma の設定が完了したら次のセクションに進み、Colab 環境の環境変数を設定します。

2. 環境変数を設定する

KAGGLE_USERNAMEKAGGLE_KEY の環境変数を設定します。「アクセスを許可しますか?」というメッセージが表示されたら、シークレット アクセスの提供に同意します。

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. gemma ライブラリをインストールする

現在、このノートブックを実行するには、無料の Colab ハードウェア アクセラレーションでは不十分です。Colab の従量課金制または Colab Pro を使用している場合は、[編集] > [ノートブックの設定] をクリックし、[A100 GPU] を選択して [保存] を選択して、ハードウェア アクセラレーションを有効にします。

次に、github.com/google-deepmind/gemma から Google DeepMind gemma ライブラリをインストールする必要があります。「pip の依存関係リゾルバ」に関するエラーが発生した場合、通常は無視してかまいません。

pip install -q git+https://github.com/google-deepmind/gemma.git

4. ライブラリをインポートする

このノートブックでは、Flax(ニューラル ネットワーク用)、コア JAXSentencePiece(トークン化用)、Chex(信頼性の高い JAX コードを作成するためのユーティリティのライブラリ)、TensorFlow データセットを使用します。

import os
import enum
import re
import string

import chex
import jax
import jax.numpy as jnp
import optax

import tensorflow as tf
import tensorflow_datasets as tfds

from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm

Gemma モデルを読み込む

kagglehub.model_download を使用して Gemma モデルを読み込みます。これは 3 つの引数を取ります。

  • handle: Kaggle のモデルハンドル
  • path: (省略可能な文字列)ローカルパス
  • force_download: (ブール値)(省略可)モデルを強制的に再ダウンロードします。
GEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/gemma/flax/2b-it/2/download...
100%|██████████| 3.67G/3.67G [00:26<00:00, 147MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2

モデルの重みとトークナイザの場所を確認してから、パス変数を設定します。トークナイザ ディレクトリは、モデルをダウンロードしたメイン ディレクトリにありますが、モデルの重みはサブディレクトリにあります。次に例を示します。

  • tokenizer.model ファイルは /LOCAL/PATH/TO/gemma/flax/2b-it/2 にあります)。
  • モデルのチェックポイントは /LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it にあります)。
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/tokenizer.model

MTNT データセットと Gemma トークナイザを読み込んで準備する

TensorFlow Datasets から入手できる MTNT(Machine Translation of Noisy Text)データセットを使用します。

MTNT データセットの英語からフランス語への変換データセットをダウンロードし、次に 2 つの例をサンプリングします。データセット内の各サンプルには、src: 元の英語の文と dst: 対応するフランス語の翻訳の 2 つのエントリが含まれます。

ds = tfds.load("mtnt/en-fr", split="train")

ds = ds.take(2)
ds = ds.as_numpy_iterator()

for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()
Downloading and preparing dataset 35.08 MiB (download: 35.08 MiB, generated: 11.33 MiB, total: 46.41 MiB) to /root/tensorflow_datasets/mtnt/en-fr/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Extraction completed...: 0 file [00:00, ? file/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/35692 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incomplete6YJMND/mtnt-train.tfrecord*...:   0%|          …
Generating test examples...:   0%|          | 0/1020 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incomplete6YJMND/mtnt-test.tfrecord*...:   0%|          |…
Generating valid examples...:   0%|          | 0/811 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incomplete6YJMND/mtnt-valid.tfrecord*...:   0%|          …
Dataset mtnt downloaded and prepared to /root/tensorflow_datasets/mtnt/en-fr/1.0.0. Subsequent calls will reuse this data.
Example 0:
dst: b'Le groupe de " toutes les \xc3\xa9toiles potentielles de la conf\xc3\xa9rence de l\'Est mais qui ne s\'en sortent pas dans le groupe de l\'Ouest ".'
src: b'The group of \xe2\x80\x9ceastern conference potential all stars but not making it in the West\xe2\x80\x9d group.'

Example 1:
dst: b"Kameron est-elle un peu aigrie de son manque de temps \xc3\xa0 l'\xc3\xa9cran ?"
src: b'Is Kameron a Little Salty About Her Lack of Air Time?'

sentencepiece.SentencePieceProcessor を使用して作成された Gemma トークナイザを読み込みます。

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True

英語からフランス語への翻訳タスク用に SentencePieceProcessor をカスタマイズします。Gemma モデルの英語部分を微調整するため、次のようないくつかの調整を行う必要があります。

  • 入力接頭辞: 各入力に共通の接頭辞を追加すると、変換タスクに通知されます。たとえば、Translate this into French: [INPUT_SENTENCE] のような接頭辞を持つプロンプトを使用できます。

  • 翻訳開始接尾辞: 各プロンプトの末尾に接尾辞を追加することで、Gemma モデルに翻訳プロセスを開始する正確なタイミングを指示します。新しい行で実行されます。

  • 言語モデルのトークン: Gemma モデルは、各シーケンスの先頭に「シーケンスの先頭」トークンを想定しているため、各トレーニング サンプルの最後に「シーケンスの終了」トークンを追加するだけで十分です。

    次のように SentencePieceProcessor の周囲にカスタム ラッパーを作成します。

class GemmaTokenizer:

  def __init__(self,
               spm_processor: spm.SentencePieceProcessor):
    self._spm_processor = spm_processor

  @property
  def pad_id(self) -> int:
    """Fast access to the pad ID."""
    return self._spm_processor.pad_id()

  def tokenize(self,
               example: str | bytes,
               prefix: str = '',
               suffix: str = '',
               add_eos: bool = True) -> jax.Array:
    """
    The tokenization function.

    Args:
      example: Input string to tokenize.
      prefix:  Prefix to add to the input string.
      suffix:  Suffix to add to the input string.
      add_eos: If True, add an "end of sentence" token at the end of the output
               sequence.
    Returns:
      Tokens corresponding to the input string.
    """
    int_list = [self._spm_processor.bos_id()]
    int_list.extend(self._spm_processor.EncodeAsIds(prefix + example + suffix))
    if add_eos:
      int_list.append(self._spm_processor.eos_id())

    return jnp.array(int_list, dtype=jnp.int32)

  def tokenize_tf_op(self,
                     str_tensor: tf.Tensor,
                     prefix: str = '',
                     suffix: str = '',
                     add_eos: bool = True) -> tf.Tensor:
    """A TensorFlow operator for the tokenize function."""
    encoded = tf.numpy_function(
        self.tokenize,
        [str_tensor, prefix, suffix, add_eos],
        tf.int32)
    encoded.set_shape([None])
    return encoded

  def to_string(self, tokens: jax.Array) -> str:
    """Convert an array of tokens to a string."""
    return self._spm_processor.EncodeIds(tokens.tolist())

新しいカスタム GemmaTokenizer をインスタンス化し、MTNT データセットの小規模なサンプルに適用してみましょう。

tokenizer = GemmaTokenizer(vocab)

def tokenize_source(tokenizer, example: tf.Tensor):
  return tokenizer.tokenize_tf_op(example,
                                  prefix='Translate this into French:\n',
                                  suffix='\n',
                                  add_eos=False)
def tokenize_destination(tokenizer, example: tf.Tensor):
  return tokenizer.tokenize_tf_op(example,
                                  add_eos=True)

ds = tfds.load("mtnt/en-fr",split="train")
ds = ds.take(2)
ds = ds.map(lambda x: {'src': tokenize_source(tokenizer, x['src']),
                       'dst': tokenize_destination(tokenizer, x['dst'])})
ds = ds.as_numpy_iterator()

for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()
Example 0:
src: [     2  49688    736   1280   6987 235292    108    651   2778    576
   1080 104745  11982   5736    832   8995    901    780   3547    665
    575    573   4589 235369   2778 235265    108]
dst: [     2   2025  29653    581    664  16298   1437  55563  41435   7840
    581    683 111452    581    533 235303   9776   4108   2459    679
    485 235303    479   6728    579   1806   2499    709  29653    581
    533 235303 101323  16054      1]

Example 1:
src: [     2  49688    736   1280   6987 235292    108   2437  87150    477
    476  11709 230461   8045   3636  40268    576   4252   4897 235336
    108]
dst: [     2 213606    477   1455 235290   3510    748   8268 191017   2809
    581   2032  69972    581  11495   1305    533 235303  65978   1654
      1]

MTNT データセット全体のデータローダーを構築します。

@chex.dataclass(frozen=True)
class TrainingInput:
  # Input tokens provided to the model.
  input_tokens: jax.Array

  # A mask that determines which tokens contribute to the target loss
  # calculation.
  target_mask: jax.Array

class DatasetSplit(enum.Enum):
  TRAIN = 'train'
  VALIDATION = 'valid'

class MTNTDatasetBuilder:
  """The dataset builder for the MTNT dataset."""

  N_ITEMS = {DatasetSplit.TRAIN: 35_692,
             DatasetSplit.VALIDATION: 811}

  BUFFER_SIZE_SHUFFLE = 10_000
  TRANSLATION_PREFIX = 'Translate this into French:\n'
  TRANSLATION_SUFFIX = '\n'

  def __init__(self,
               tokenizer : GemmaTokenizer,
               max_seq_len: int):
    """Constructor.

    Args:
      tokenizer: Gemma tokenizer to use.
      max_seq_len: size of each sequence in a given batch.
    """
    self._tokenizer = tokenizer
    self._base_data = {
        DatasetSplit.TRAIN: tfds.load("mtnt/en-fr",split="train"),
        DatasetSplit.VALIDATION: tfds.load("mtnt/en-fr",split="valid"),
    }
    self._max_seq_len = max_seq_len

  def _tokenize_source(self, example: tf.Tensor):
    """Tokenization function for the source."""
    return self._tokenizer.tokenize_tf_op(example,
                                          prefix=self.TRANSLATION_PREFIX,
                                          suffix=self.TRANSLATION_SUFFIX,
                                          add_eos=False)

  def _tokenize_destination(self, example: tf.Tensor):
    """Tokenization function for the French translation."""
    return self._tokenizer.tokenize_tf_op(example,
                                          add_eos=True)

  def _pad_up_to_max_len(self,
                         input_tensor: tf.Tensor,
                         pad_value: int | bool,
                         ) -> tf.Tensor:
    """Pad the given tensor up to sequence length of a batch."""
    seq_len = tf.shape(input_tensor)[0]
    to_pad = tf.maximum(self._max_seq_len - seq_len, 0)
    return tf.pad(input_tensor,
                  [[0, to_pad]],
                  mode='CONSTANT',
                  constant_values=pad_value,
                  )

  def _to_training_input(self,
                         src_tokens: jax.Array,
                         dst_tokens: jax.Array,
                         ) -> TrainingInput:
    """Build a training input from a tuple of source and destination tokens."""

    # The input sequence fed to the model is simply the concatenation of the
    # source and the destination.
    tokens = tf.concat([src_tokens, dst_tokens], axis=0)

    # To prevent the model from updating based on the source (input)
    # tokens, add a target mask to each input.
    q_mask = tf.zeros_like(src_tokens, dtype=tf.bool)
    a_mask = tf.ones_like(dst_tokens, dtype=tf.bool)
    mask = tf.concat([q_mask, a_mask], axis=0)

    # If the output tokens sequence is smaller than the target sequence size,
    # then pad it with pad tokens.
    tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)

    # Don't want to perform the backward pass on the pad tokens.
    mask = self._pad_up_to_max_len(mask, False)

    return TrainingInput(input_tokens=tokens, target_mask=mask)


  def get_train_dataset(self, batch_size: int, num_epochs: int):
    """Build the training dataset."""

    # Tokenize each sample.
    ds = self._base_data[DatasetSplit.TRAIN].map(lambda x : (self._tokenize_source(x['src']),
                                                             self._tokenize_destination(x['dst'])))

    # Convert the samples to training inputs.
    ds = ds.map(lambda x, y: self._to_training_input(x, y))

    # Remove the samples that are too long.
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)

    # Shuffle the dataset.
    ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)

    # Repeat if necessary.
    ds = ds.repeat(num_epochs)

    # Build batches.
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

  def get_validation_dataset(self, batch_size: int):
    """Build the validation dataset."""

    # Same steps as in `get_train_dataset`, but without shuffling and no repetition.
    ds = self._base_data[DatasetSplit.VALIDATION].map(lambda x : (self._tokenize_source(x['src']),
                                                                  self._tokenize_destination(x['dst'])))
    ds = ds.map(lambda x, y: self._to_training_input(x, y))
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

MTNTDatasetBuilder を試してみましょう。カスタム GemmaTokenizer を再度インスタンス化してから MTNT データセットに適用し、2 つの例をサンプリングします。

tokenizer = GemmaTokenizer(vocab)

dataset_builder = MTNTDatasetBuilder(tokenizer, max_seq_len=20)
ds = dataset_builder.get_train_dataset(3, 1)
ds = ds.take(2)
ds = ds.as_numpy_iterator()

for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'>
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'>
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'>
Example 0:
input_tokens: [[     2  49688    736   1280   6987 235292    108  10924    665  12302
  235341    108      2   4397  63011   1437  38696   1241      1      0]
 [     2  49688    736   1280   6987 235292    108  13835   1517 235265
     108      2  69875    540  19713 235265      1      0      0      0]
 [     2  49688    736   1280   6987 235292    108   6956   1586 235297
  235265    108      2  78368   1586 235297 235265      1      0      0]]
target_mask: [[False False False False False False False False False False False False
   True  True  True  True  True  True  True False]
 [False False False False False False False False False False False  True
   True  True  True  True  True False False False]
 [False False False False False False False False False False False False
   True  True  True  True  True  True False False]]

Example 1:
input_tokens: [[     2  49688    736   1280   6987 235292    108  18874 235341    108
       2 115905   6425   1241      1      0      0      0      0      0]
 [     2  49688    736   1280   6987 235292    108   7574   3356 235341
     108      2   7997  20707   1241      1      0      0      0      0]
 [     2  49688    736   1280   6987 235292    108   8703    665 235265
     108      2 235338 235303  90006  20133 235265      1      0      0]]
target_mask: [[False False False False False False False False False False  True  True
   True  True  True False False False False False]
 [False False False False False False False False False False False  True
   True  True  True  True False False False False]
 [False False False False False False False False False False False  True
   True  True  True  True  True  True False False]]

モデルを構成する

Gemma モデルのファインチューニングを開始する前に、Gemma モデルを構成する必要があります。

まず、gemma.params.load_and_format_params メソッドを使用して、Gemma モデル チェックポイントを読み込み、フォーマットします。

params = params_lib.load_and_format_params(CKPT_PATH)

Gemma モデル チェックポイントから正しい構成を自動的に読み込むには、gemma.transformer.TransformerConfig を使用します。cache_size 引数は、Gemma Transformer キャッシュ内のタイムステップの数です。その後、gemma.transformer.Transformerflax.linen.Module から継承)を使用して、Gemma モデルを model_2b としてインスタンス化します。

config_2b = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=30
)

model_2b = transformer_lib.Transformer(config=config_2b)

モデルをファインチューニングする

このセクションでは、次の作業を行います。

  • gemma.transformer.Transformer クラスを使用して、フォワード パス / 損失関数を作成します。
  • トークンの位置とアテンションのマスクベクトルを作成する
  • Flax を使用してトレーニング ステップ関数を作成する。
  • バックワード パスを使用せずに検証ステップを作成します。
  • トレーニング ループを作成する。
  • Gemma モデルをファインチューニングします。

gemma.transformer.Transformer クラスを使用して、フォワードパスと損失関数を定義します。Gemma Transformerflax.linen.Module を継承し、次の 2 つの重要なメソッドを備えています。

  • init: モデルのパラメータを初期化します。
  • apply: 指定されたパラメータ セットを使用して、モデルの __call__ 関数を実行します。

    トレーニング済みの Gemma 重みを扱っているため、init 関数を使用する必要はありません。

def forward_and_loss_fn(params,
                        *,
                        model: transformer_lib.Transformer,
                        input_tokens: jax.Array,            # Shape [B, L]
                        input_mask: jax.Array,              # Shape [B, L]
                        positions: jax.Array,               # Shape [B, L]
                        attention_mask: jax.Array,          # [B, L, L]
                        ) -> jax.Array:
  """The forward pass and the loss function.

  Args:
    params: Model's input parameters.
    model: The Gemma transformer model to call.
    input_tokens: Input tokens sequence, shape [B, L].
    input_mask: Tokens to ignore when computing the loss, shape [B, L].
    positions: Relative position of each token, shape [B, L].
    attention_mask: Input attention mask, shape [B, L].

  Returns:
    The softmax cross-entropy loss for the next-token prediction task.
  """

  # The forward pass on the input data.
  # No attention cache is needed here.
  logits, _ = model.apply(
        params,
        input_tokens,
        positions,
        None,              # Attention cache is None.
        attention_mask,
    )

  # Exclude the last step as it does not appear in the targets.
  logits = logits[0, :-1]

  # Similarly, the first token cannot be predicted.
  target_tokens = input_tokens[0, 1:]
  target_mask = input_mask[0, 1:]

  # Convert the target labels to one-hot encoded vectors.
  one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])

  # Don't update on unwanted tokens.
  one_hot = one_hot * target_mask.astype(one_hot.dtype)[...,None]

  # Define the normalization factor.
  norm_factor = 1 / (jnp.sum(target_mask) + 1e-8)

  # Return the negative log likelihood (NLL) loss.
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) * norm_factor

gemma.transformer.Transformer クラスには、各入力とともに attention_mask ベクトルと positions ベクトルが必要です。これらを生成するには、Transformer.build_positions_from_maskTransformer.make_causal_attn_mask を使用するカスタム関数を作成します。

def get_attention_mask_and_positions(example: jax.Array,
                                     pad_id : int,
                                     )-> tuple[jax.Array, jax.Array]:
  """Builds the position and attention mask vectors from the given tokens."""
  pad_mask = example != pad_id
  current_token_position = transformer_lib.build_positions_from_mask(pad_mask)
  attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)
  return current_token_position, attention_mask

バックワード パスを実行し、それに応じてモデルのパラメータを更新する train_step 関数を作成します。ここで、

  • jax.value_and_grad は、フォワードパスとバックワード パス時の損失関数と勾配を評価するために使用されます。
  • optax.apply_updates は、パラメータの更新に使用します。
def train_step(model: transformer_lib.Transformer,
               params,
               optimizer: optax.GradientTransformation,
               opt_state: optax.OptState,
               pad_id: int,
               example: TrainingInput):
  """Train step.

  Args:
    model: The Gemma transformer model.
    params: The model's input parameters.
    optimizer: The Optax optimizer to use.
    opt_state: The input optimizer's state.
    pad_id: ID of the pad token.
    example: Input batch.

  Returns:
    The training loss, the updated parameters, and the updated optimizer state.
  """

  # Build the position and attention mask vectors.
  positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)

  # The forward and backward passes.
  train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(params,
                                                             model=model,
                                                             input_tokens=example.input_tokens,
                                                             input_mask=example.target_mask,
                                                             positions=positions,
                                                             attention_mask=attention_mask)
  # Update the parameters.
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)

  return train_loss, params, opt_state

バックワード パスなしで validation_step 関数を作成します。

def validation_step(model: transformer_lib.Transformer,
                    params,
                    pad_id: int,
                    example: TrainingInput,
                    ):
  positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)
  val_loss = forward_and_loss_fn(params,
                                 model=model,
                                 input_tokens=example.input_tokens,
                                 input_mask=example.target_mask,
                                 positions=positions,
                                 attention_mask=attention_mask)
  return val_loss

optax.sgd を使用して SGD オプティマイザーのトレーニング ループを定義します。

@chex.dataclass(frozen=True)
class TrainingConfig:
  learning_rate: float
  num_epochs: int
  eval_every_n: int
  batch_size: int
  max_steps: int | None = None

def train_loop(
    model: transformer_lib.Transformer,
    params,
    dataset_builder: MTNTDatasetBuilder,
    training_cfg: TrainingConfig):

  # Apply `jax.jit` on the training step, making the whole loop much more efficient.
  compiled_train_step = jax.jit(train_step, static_argnames=['model', 'optimizer'])

  # Apply `jax.jit` on the validation step.
  compiled_validation_step = jax.jit(validation_step, static_argnames=['model'])

  # To save memory, use the SGD optimizer instead of the usual Adam optimizer.
  # Note that for this specific example, SGD is more than enough.
  optimizer = optax.sgd(training_cfg.learning_rate)
  opt_state = optimizer.init(params)

  # Build the training dataset.
  train_ds = dataset_builder.get_train_dataset(batch_size=training_cfg.batch_size,
                                               num_epochs=training_cfg.num_epochs)
  train_ds = train_ds.as_numpy_iterator()

  # Build the validation dataset, with a limited number of samples for this demo.
  validation_ds = dataset_builder.get_validation_dataset(batch_size=training_cfg.batch_size)
  validation_ds = validation_ds.take(50)

  n_steps = 0
  avg_loss=0

  # A first round of the validation loss.
  n_steps_eval = 0
  eval_loss = 0
  val_iterator = validation_ds.as_numpy_iterator()
  for val_example in val_iterator:
    eval_loss += compiled_validation_step(model,
                                          params,
                                          dataset_builder._tokenizer.pad_id,
                                          val_example)
    n_steps_eval += 1
  print(f"Start, validation loss: {eval_loss/n_steps_eval}")

  for train_example in train_ds:
    train_loss, params, opt_state = compiled_train_step(model=model,
                                                        params=params,
                                                        optimizer=optimizer,
                                                        opt_state=opt_state,
                                                        pad_id=dataset_builder._tokenizer.pad_id,
                                                        example=train_example)
    n_steps += 1
    avg_loss += train_loss
    if n_steps % training_cfg.eval_every_n == 0:
      eval_loss = 0

      n_steps_eval = 0
      val_iterator = validation_ds.as_numpy_iterator()
      for val_example in val_iterator:
        eval_loss += compiled_validation_step(model,
                                              params,
                                              dataset_builder._tokenizer.pad_id,
                                              val_example)
        n_steps_eval +=1
      avg_loss /= training_cfg.eval_every_n
      eval_loss /= n_steps_eval
      print(f"STEP {n_steps} training loss: {avg_loss} - eval loss: {eval_loss}")
      avg_loss=0
    if training_cfg.max_steps is not None and n_steps > training_cfg.max_steps:
      break
  return params

メモリに収まるように、限られたステップ数(SEQ_SIZE)で Gemma モデルの微調整を開始します。

SEQ_SIZE = 25
tokenizer = GemmaTokenizer(vocab)
dataset_builder= MTNTDatasetBuilder(tokenizer, SEQ_SIZE)
training_cfg = TrainingConfig(learning_rate=1e-4,
                              num_epochs=1,
                              eval_every_n=20,
                              batch_size=1,
                              max_steps=100)

params = train_loop(model=model_2b,
                    params={'params': params['transformer']},
                    dataset_builder=dataset_builder,
                    training_cfg=training_cfg)
Start, validation loss: 10.647212982177734
STEP 20 training loss: 3.3015992641448975 - eval loss: 2.686880111694336
STEP 40 training loss: 5.375057220458984 - eval loss: 2.6751961708068848
STEP 60 training loss: 2.6599338054656982 - eval loss: 2.663877010345459
STEP 80 training loss: 4.822389125823975 - eval loss: 2.3333375453948975
STEP 100 training loss: 2.0131142139434814 - eval loss: 2.360811948776245

トレーニングの損失と検証の損失の両方が、ステップ数ごとに減少しているはずです。

gemma.sampler.Samplersampler を作成します。Gemma モデル チェックポイントとトークナイザを使用します。

sampler = sampler_lib.Sampler(
    transformer=model_2b,
    vocab=vocab,
    params=params['params'],
)

sampler を使用して、モデルが翻訳できるかどうかを確認します。gemma.sampler.Samplertotal_generation_steps 引数は、レスポンスの生成時に実行されるステップの数です。入力をトレーニング形式に一致させるには、接頭辞 Translate this into French:\n の末尾に改行文字を追加します。これにより、モデルに翻訳の開始を指示できます。

sampler(
    ["Translate this into French:\nHello, my name is Morgane.\n"],
    total_generation_steps=100,
    ).text
["C'est Bonjour, mon nom est Morgane.C'est Bonjour, mon nom est Morgane."]

詳細

  • Google DeepMind の gemma ライブラリの詳細を確認できます。これには、このチュートリアルで使用したモジュールの docstring(gemma.paramsgemma.transformergemma.sampler など)が含まれています。
  • ライブラリに独自のドキュメント サイト(core JAXFlaxChexOptaxOrbax)があります。
  • sentencepiece トークナイザ/デトークナイザのドキュメントについては、Google の sentencepiece GitHub リポジトリをご覧ください。
  • kagglehub のドキュメントについては、Kaggle の kagglehub GitHub リポジトリREADME.md をご覧ください。
  • Google Cloud Vertex AI で Gemma モデルを使用する方法を学習する。
  • Google Cloud TPU(v3-8 以降)を使用している場合は、最新の jax[tpu] パッケージ(!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html)に更新し、ランタイムを再起動して、jaxjaxlib のバージョンが一致している(!pip list | grep jax)ことを確認します。これにより、jaxlibjax のバージョンの不一致が原因で発生する RuntimeError を防ぐことができます。JAX の詳細なインストール手順については、JAX のドキュメントをご覧ください。