Fine-tune Gemma using JAX and Flax


Overview

Gemma is a family of lightweight, state-of-the-art open large language models, based on research and technology from Google DeepMind Gemini. This tutorial shows how to fine-tune the Gemma 2B Instruct model for an English-to-French translation task using Google DeepMind's gemma library, JAX (a high-performance numerical computing library), Flax (the JAX-based neural network library), Chex (a library of utilities for writing reliable JAX code), and Optax (the JAX-based gradient processing and optimization library). Although this notebook does not use Flax directly, Flax was used to create Gemma.

The gemma library was written with JAX, Flax, Orbax (a JAX-based library for training utilities such as checkpointing), and SentencePiece (a tokenizer/detokenizer library).

Setup

1. Set up Kaggle access for Gemma

To complete this tutorial, you first need to follow the setup instructions in Gemma setup, which show you how to do the following:

  • Get access to Gemma on kaggle.com.
  • Select a Colab runtime with sufficient resources to run the Gemma model.
  • Generate and configure a Kaggle username and API key.

After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.

2. Set environment variables

Set environment variables for KAGGLE_USERNAME and KAGGLE_KEY. When you are asked "Grant access?", agree to provide secret access.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. Install the gemma library

Free Colab hardware acceleration is currently insufficient to run this notebook. If you are using Colab Pay As You Go or Colab Pro, click Edit > Notebook settings > Select A100 GPU > Save to enable hardware acceleration.

Next, install the Google DeepMind gemma library from github.com/google-deepmind/gemma. If you see an error about "pip's dependency resolver", you can usually ignore it.

pip install -q git+https://github.com/google-deepmind/gemma.git

4. Import libraries

This notebook uses Flax (for neural networks), core JAX, SentencePiece (for tokenization), Chex (a library of utilities for writing reliable JAX code), Optax (the gradient processing and optimization library), and TensorFlow Datasets.

import os
import enum
import re
import string

import chex
import jax
import jax.numpy as jnp
import optax

import tensorflow as tf
import tensorflow_datasets as tfds

from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm

Load the Gemma model

Load the Gemma model with kagglehub.model_download, which takes three arguments:

  • handle: The model handle from Kaggle
  • path: (Optional string) The local path
  • force_download: (Optional boolean) Forces a re-download of the model

GEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/gemma/flax/2b-it/2/download...
100%|██████████| 3.67G/3.67G [00:26<00:00, 147MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2

Check the location of the model weights and the tokenizer, then set the path variables. The tokenizer directory is in the main directory where you downloaded the model, while the model weights are in a sub-directory. For example:

  • The tokenizer.model file will be in /LOCAL/PATH/TO/gemma/flax/2b-it/2.
  • The model checkpoint will be in /LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it.

CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/tokenizer.model

Load and prepare the MTNT dataset and the Gemma tokenizer

You will use the MTNT (Machine Translation of Noisy Text) dataset, available from TensorFlow Datasets.

Download the English-to-French portion of the MTNT dataset, then sample two examples. Each sample in the dataset contains two entries: src: the original English sentence; and dst: the corresponding French translation.

ds = tfds.load("mtnt/en-fr", split="train")

ds = ds.take(2)
ds = ds.as_numpy_iterator()

for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()
Downloading and preparing dataset 35.08 MiB (download: 35.08 MiB, generated: 11.33 MiB, total: 46.41 MiB) to /root/tensorflow_datasets/mtnt/en-fr/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Extraction completed...: 0 file [00:00, ? file/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/35692 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incomplete6YJMND/mtnt-train.tfrecord*...:   0%|          …
Generating test examples...:   0%|          | 0/1020 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incomplete6YJMND/mtnt-test.tfrecord*...:   0%|          |…
Generating valid examples...:   0%|          | 0/811 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incomplete6YJMND/mtnt-valid.tfrecord*...:   0%|          …
Dataset mtnt downloaded and prepared to /root/tensorflow_datasets/mtnt/en-fr/1.0.0. Subsequent calls will reuse this data.
Example 0:
dst: b'Le groupe de " toutes les \xc3\xa9toiles potentielles de la conf\xc3\xa9rence de l\'Est mais qui ne s\'en sortent pas dans le groupe de l\'Ouest ".'
src: b'The group of \xe2\x80\x9ceastern conference potential all stars but not making it in the West\xe2\x80\x9d group.'

Example 1:
dst: b"Kameron est-elle un peu aigrie de son manque de temps \xc3\xa0 l'\xc3\xa9cran ?"
src: b'Is Kameron a Little Salty About Her Lack of Air Time?'

Load the Gemma tokenizer, constructed using sentencepiece.SentencePieceProcessor:

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
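
As an optional sanity check (not part of the original notebook), you can encode and decode a short string with the raw SentencePiece tokenizer before customizing it; the example string is arbitrary:

ids = vocab.EncodeAsIds('Hello world!')
print(ids)                                             # a short list of integer token IDs
print(vocab.DecodeIds(ids))                            # 'Hello world!'
print(vocab.bos_id(), vocab.eos_id(), vocab.pad_id())  # special token IDs used below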

Customize the SentencePieceProcessor for the English-to-French translation task. Since you will be fine-tuning the Gemma model for this task, you need to make a few adjustments, such as the following (a short illustrative sketch follows the list):

  • Input prefix: Adding a common prefix to each input signals the translation task. For example, you could use a prompt with a prefix such as Translate this into French: [INPUT_SENTENCE].

  • Translation start suffix: Adding a suffix at the end of each prompt tells the Gemma model exactly when to begin the translation process. A new line should do the job.

  • Language model tokens: The Gemma model expects a "beginning of sequence" token at the start of each sequence, so you only need to add an "end of sequence" token at the end of each training example.
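
Putting these adjustments together, the plain-text form of a single training example looks roughly like the sketch below; the French target sentence is made up for illustration only:

source = (
    'Translate this into French:\n'        # shared task prefix
    'Hello, my name is Morgane.\n'         # English input followed by the newline suffix
)
target = "Bonjour, je m'appelle Morgane."  # French target; an "end of sequence" token is appended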

Build a custom wrapper around the SentencePieceProcessor as follows:

class GemmaTokenizer:

  def __init__(self,
               spm_processor: spm.SentencePieceProcessor):
    self._spm_processor = spm_processor

  @property
  def pad_id(self) -> int:
    """Fast access to the pad ID."""
    return self._spm_processor.pad_id()

  def tokenize(self,
               example: str | bytes,
               prefix: str = '',
               suffix: str = '',
               add_eos: bool = True) -> jax.Array:
    """
    The tokenization function.

    Args:
      example: Input string to tokenize.
      prefix:  Prefix to add to the input string.
      suffix:  Suffix to add to the input string.
      add_eos: If True, add an "end of sentence" token at the end of the output
               sequence.
    Returns:
      Tokens corresponding to the input string.
    """
    int_list = [self._spm_processor.bos_id()]
    int_list.extend(self._spm_processor.EncodeAsIds(prefix + example + suffix))
    if add_eos:
      int_list.append(self._spm_processor.eos_id())

    return jnp.array(int_list, dtype=jnp.int32)

  def tokenize_tf_op(self,
                     str_tensor: tf.Tensor,
                     prefix: str = '',
                     suffix: str = '',
                     add_eos: bool = True) -> tf.Tensor:
    """A TensorFlow operator for the tokenize function."""
    encoded = tf.numpy_function(
        self.tokenize,
        [str_tensor, prefix, suffix, add_eos],
        tf.int32)
    encoded.set_shape([None])
    return encoded

  def to_string(self, tokens: jax.Array) -> str:
    """Convert an array of tokens to a string."""
    return self._spm_processor.DecodeIds(tokens.tolist())
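
Before wiring the tokenizer into a tf.data pipeline, you can sanity-check the round trip on a plain Python string. This is a minimal sketch (not part of the original notebook) that assumes the class above has been defined; the sentence is arbitrary:

tokenizer = GemmaTokenizer(vocab)

ids = tokenizer.tokenize('Hello, my name is Morgane.',
                         prefix='Translate this into French:\n',
                         suffix='\n')
print(ids)                       # a 1-D int32 array starting with the BOS token ID
print(tokenizer.to_string(ids))  # decodes back to (roughly) the prefixed text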

Try out your new custom GemmaTokenizer by instantiating it and then applying it on a small sample of the MTNT dataset:

tokenizer = GemmaTokenizer(vocab)

def tokenize_source(tokenizer, example: tf.Tensor):
  return tokenizer.tokenize_tf_op(example,
                                  prefix='Translate this into French:\n',
                                  suffix='\n',
                                  add_eos=False)
def tokenize_destination(tokenizer, example: tf.Tensor):
  return tokenizer.tokenize_tf_op(example,
                                  add_eos=True)

ds = tfds.load("mtnt/en-fr",split="train")
ds = ds.take(2)
ds = ds.map(lambda x: {'src': tokenize_source(tokenizer, x['src']),
                       'dst': tokenize_destination(tokenizer, x['dst'])})
ds = ds.as_numpy_iterator()

for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()
Example 0:
src: [     2  49688    736   1280   6987 235292    108    651   2778    576
   1080 104745  11982   5736    832   8995    901    780   3547    665
    575    573   4589 235369   2778 235265    108]
dst: [     2   2025  29653    581    664  16298   1437  55563  41435   7840
    581    683 111452    581    533 235303   9776   4108   2459    679
    485 235303    479   6728    579   1806   2499    709  29653    581
    533 235303 101323  16054      1]

Example 1:
src: [     2  49688    736   1280   6987 235292    108   2437  87150    477
    476  11709 230461   8045   3636  40268    576   4252   4897 235336
    108]
dst: [     2 213606    477   1455 235290   3510    748   8268 191017   2809
    581   2032  69972    581  11495   1305    533 235303  65978   1654
      1]

Build a data loader for the entire MTNT dataset:

@chex.dataclass(frozen=True)
class TrainingInput:
  # Input tokens provided to the model.
  input_tokens: jax.Array

  # A mask that determines which tokens contribute to the target loss
  # calculation.
  target_mask: jax.Array

class DatasetSplit(enum.Enum):
  TRAIN = 'train'
  VALIDATION = 'valid'

class MTNTDatasetBuilder:
  """The dataset builder for the MTNT dataset."""

  N_ITEMS = {DatasetSplit.TRAIN: 35_692,
             DatasetSplit.VALIDATION: 811}

  BUFFER_SIZE_SHUFFLE = 10_000
  TRANSLATION_PREFIX = 'Translate this into French:\n'
  TRANSLATION_SUFFIX = '\n'

  def __init__(self,
               tokenizer : GemmaTokenizer,
               max_seq_len: int):
    """Constructor.

    Args:
      tokenizer: Gemma tokenizer to use.
      max_seq_len: size of each sequence in a given batch.
    """
    self._tokenizer = tokenizer
    self._base_data = {
        DatasetSplit.TRAIN: tfds.load("mtnt/en-fr",split="train"),
        DatasetSplit.VALIDATION: tfds.load("mtnt/en-fr",split="valid"),
    }
    self._max_seq_len = max_seq_len

  def _tokenize_source(self, example: tf.Tensor):
    """Tokenization function for the source."""
    return self._tokenizer.tokenize_tf_op(example,
                                          prefix=self.TRANSLATION_PREFIX,
                                          suffix=self.TRANSLATION_SUFFIX,
                                          add_eos=False)

  def _tokenize_destination(self, example: tf.Tensor):
    """Tokenization function for the French translation."""
    return self._tokenizer.tokenize_tf_op(example,
                                          add_eos=True)

  def _pad_up_to_max_len(self,
                         input_tensor: tf.Tensor,
                         pad_value: int | bool,
                         ) -> tf.Tensor:
    """Pad the given tensor up to sequence length of a batch."""
    seq_len = tf.shape(input_tensor)[0]
    to_pad = tf.maximum(self._max_seq_len - seq_len, 0)
    return tf.pad(input_tensor,
                  [[0, to_pad]],
                  mode='CONSTANT',
                  constant_values=pad_value,
                  )

  def _to_training_input(self,
                         src_tokens: jax.Array,
                         dst_tokens: jax.Array,
                         ) -> TrainingInput:
    """Build a training input from a tuple of source and destination tokens."""

    # The input sequence fed to the model is simply the concatenation of the
    # source and the destination.
    tokens = tf.concat([src_tokens, dst_tokens], axis=0)

    # To prevent the model from updating based on the source (input)
    # tokens, add a target mask to each input.
    q_mask = tf.zeros_like(src_tokens, dtype=tf.bool)
    a_mask = tf.ones_like(dst_tokens, dtype=tf.bool)
    mask = tf.concat([q_mask, a_mask], axis=0)

    # If the output tokens sequence is smaller than the target sequence size,
    # then pad it with pad tokens.
    tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)

    # Don't want to perform the backward pass on the pad tokens.
    mask = self._pad_up_to_max_len(mask, False)

    return TrainingInput(input_tokens=tokens, target_mask=mask)


  def get_train_dataset(self, batch_size: int, num_epochs: int):
    """Build the training dataset."""

    # Tokenize each sample.
    ds = self._base_data[DatasetSplit.TRAIN].map(lambda x : (self._tokenize_source(x['src']),
                                                             self._tokenize_destination(x['dst'])))

    # Convert the samples to training inputs.
    ds = ds.map(lambda x, y: self._to_training_input(x, y))

    # Remove the samples that are too long.
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)

    # Shuffle the dataset.
    ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)

    # Repeat if necessary.
    ds = ds.repeat(num_epochs)

    # Build batches.
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

  def get_validation_dataset(self, batch_size: int):
    """Build the validation dataset."""

    # Same steps as in `get_train_dataset`, but without shuffling and no repetition.
    ds = self._base_data[DatasetSplit.VALIDATION].map(lambda x : (self._tokenize_source(x['src']),
                                                                  self._tokenize_destination(x['dst'])))
    ds = ds.map(lambda x, y: self._to_training_input(x, y))
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

Try out the MTNTDatasetBuilder by instantiating the custom GemmaTokenizer again, then applying it on the MTNT dataset and sampling two examples:

tokenizer = GemmaTokenizer(vocab)

dataset_builder = MTNTDatasetBuilder(tokenizer, max_seq_len=20)
ds = dataset_builder.get_train_dataset(3, 1)
ds = ds.take(2)
ds = ds.as_numpy_iterator()

for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'>
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'>
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'>
Example 0:
input_tokens: [[     2  49688    736   1280   6987 235292    108  10924    665  12302
  235341    108      2   4397  63011   1437  38696   1241      1      0]
 [     2  49688    736   1280   6987 235292    108  13835   1517 235265
     108      2  69875    540  19713 235265      1      0      0      0]
 [     2  49688    736   1280   6987 235292    108   6956   1586 235297
  235265    108      2  78368   1586 235297 235265      1      0      0]]
target_mask: [[False False False False False False False False False False False False
   True  True  True  True  True  True  True False]
 [False False False False False False False False False False False  True
   True  True  True  True  True False False False]
 [False False False False False False False False False False False False
   True  True  True  True  True  True False False]]

Example 1:
input_tokens: [[     2  49688    736   1280   6987 235292    108  18874 235341    108
       2 115905   6425   1241      1      0      0      0      0      0]
 [     2  49688    736   1280   6987 235292    108   7574   3356 235341
     108      2   7997  20707   1241      1      0      0      0      0]
 [     2  49688    736   1280   6987 235292    108   8703    665 235265
     108      2 235338 235303  90006  20133 235265      1      0      0]]
target_mask: [[False False False False False False False False False False  True  True
   True  True  True False False False False False]
 [False False False False False False False False False False False  True
   True  True  True  True False False False False]
 [False False False False False False False False False False False  True
   True  True  True  True  True  True False False]]

Configure the model

Before you begin fine-tuning the Gemma model, you need to configure it.

First, load and format the Gemma model checkpoint with the gemma.params.load_and_format_params method:

params = params_lib.load_and_format_params(CKPT_PATH)
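
The returned params object is a nested dictionary of arrays. As an optional check (not part of the original notebook), you can inspect its top-level keys and count the parameters:

print(list(params.keys()))   # expected to contain 'transformer', which is used again below
n_params = sum(p.size for p in jax.tree_util.tree_leaves(params))
print(f'Total parameters: {n_params:,}')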

To automatically load the correct configuration from the Gemma model checkpoint, use gemma.transformer.TransformerConfig. The cache_size argument is the number of time steps in the Gemma Transformer cache. Afterwards, instantiate the Gemma model as model_2b with gemma.transformer.Transformer (which inherits from flax.linen.Module):

config_2b = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=30
)

model_2b = transformer_lib.Transformer(config=config_2b)

Fine-tune the model

In this section, you will do the following:

  • Use the gemma.transformer.Transformer class to create the forward pass and loss function.
  • Build the position and attention mask vectors for the tokens.
  • Build the training step function with Flax.
  • Build the validation step without the backward pass.
  • Build the training loop.
  • Fine-tune the Gemma model.

Define the forward pass and the loss function using the gemma.transformer.Transformer class. The Gemma Transformer inherits from flax.linen.Module and offers two essential methods:

  • init: Initializes the model's parameters.
  • apply: Executes the model's __call__ function using a given set of parameters.

Since you are working with pre-trained Gemma weights, you don't need to use the init function.

def forward_and_loss_fn(params,
                        *,
                        model: transformer_lib.Transformer,
                        input_tokens: jax.Array,            # Shape [B, L]
                        input_mask: jax.Array,              # Shape [B, L]
                        positions: jax.Array,               # Shape [B, L]
                        attention_mask: jax.Array,          # [B, L, L]
                        ) -> jax.Array:
  """The forward pass and the loss function.

  Args:
    params: Model's input parameters.
    model: The Gemma transformer model to call.
    input_tokens: Input tokens sequence, shape [B, L].
    input_mask: Tokens to ignore when computing the loss, shape [B, L].
    positions: Relative position of each token, shape [B, L].
    attention_mask: Input attention mask, shape [B, L].

  Returns:
    The softmax cross-entropy loss for the next-token prediction task.
  """

  # The forward pass on the input data.
  # No attention cache is needed here.
  logits, _ = model.apply(
        params,
        input_tokens,
        positions,
        None,              # Attention cache is None.
        attention_mask,
    )

  # Exclude the last step as it does not appear in the targets.
  logits = logits[0, :-1]

  # Similarly, the first token cannot be predicted.
  target_tokens = input_tokens[0, 1:]
  target_mask = input_mask[0, 1:]

  # Convert the target labels to one-hot encoded vectors.
  one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])

  # Don't update on unwanted tokens.
  one_hot = one_hot * target_mask.astype(one_hot.dtype)[...,None]

  # Define the normalization factor.
  norm_factor = 1 / (jnp.sum(target_mask) + 1e-8)

  # Return the negative log likelihood (NLL) loss.
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) * norm_factor

The gemma.transformer.Transformer class requires an attention_mask and a positions vector alongside each input. You can generate them by creating a custom function that uses Transformer.build_positions_from_mask and Transformer.make_causal_attn_mask:

def get_attention_mask_and_positions(example: jax.Array,
                                     pad_id : int,
                                     )-> tuple[jax.Array, jax.Array]:
  """Builds the position and attention mask vectors from the given tokens."""
  pad_mask = example != pad_id
  current_token_position = transformer_lib.build_positions_from_mask(pad_mask)
  attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)
  return current_token_position, attention_mask
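
As a quick, optional check (not part of the original notebook), you can call the helper on a toy padded batch and look at the shapes it returns; the token values and the pad ID of 0 are arbitrary:

toy_tokens = jnp.array([[2, 10, 11, 0, 0]], dtype=jnp.int32)  # one padded sequence of length 5
positions, attn_mask = get_attention_mask_and_positions(toy_tokens, pad_id=0)
print(positions.shape)   # expected (1, 5): one position index per token
print(attn_mask.shape)   # expected (1, 5, 5): causal mask combined with the padding mask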

Build the train_step function that performs the backward pass and updates the model's parameters accordingly:

def train_step(model: transformer_lib.Transformer,
               params,
               optimizer: optax.GradientTransformation,
               opt_state: optax.OptState,
               pad_id: int,
               example: TrainingInput):
  """Train step.

  Args:
    model: The Gemma transformer model.
    params: The model's input parameters.
    optimizer: The Optax optimizer to use.
    opt_state: The input optimizer's state.
    pad_id: ID of the pad token.
    example: Input batch.

  Returns:
    The training loss, the updated parameters, and the updated optimizer state.
  """

  # Build the position and attention mask vectors.
  positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)

  # The forward and backward passes.
  train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(params,
                                                             model=model,
                                                             input_tokens=example.input_tokens,
                                                             input_mask=example.target_mask,
                                                             positions=positions,
                                                             attention_mask=attention_mask)
  # Update the parameters.
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)

  return train_loss, params, opt_state

Build the validation_step function, without the backward pass:

def validation_step(model: transformer_lib.Transformer,
                    params,
                    pad_id: int,
                    example: TrainingInput,
                    ):
  positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)
  val_loss = forward_and_loss_fn(params,
                                 model=model,
                                 input_tokens=example.input_tokens,
                                 input_mask=example.target_mask,
                                 positions=positions,
                                 attention_mask=attention_mask)
  return val_loss

Define the training loop, using optax.sgd for the SGD optimizer:

@chex.dataclass(frozen=True)
class TrainingConfig:
  learning_rate: float
  num_epochs: int
  eval_every_n: int
  batch_size: int
  max_steps: int | None = None

def train_loop(
    model: transformer_lib.Transformer,
    params,
    dataset_builder: MTNTDatasetBuilder,
    training_cfg: TrainingConfig):

  # Apply `jax.jit` on the training step, making the whole loop much more efficient.
  compiled_train_step = jax.jit(train_step, static_argnames=['model', 'optimizer'])

  # Apply `jax.jit` on the validation step.
  compiled_validation_step = jax.jit(validation_step, static_argnames=['model'])

  # To save memory, use the SGD optimizer instead of the usual Adam optimizer.
  # Note that for this specific example, SGD is more than enough.
  optimizer = optax.sgd(training_cfg.learning_rate)
  opt_state = optimizer.init(params)

  # Build the training dataset.
  train_ds = dataset_builder.get_train_dataset(batch_size=training_cfg.batch_size,
                                               num_epochs=training_cfg.num_epochs)
  train_ds = train_ds.as_numpy_iterator()

  # Build the validation dataset, with a limited number of samples for this demo.
  validation_ds = dataset_builder.get_validation_dataset(batch_size=training_cfg.batch_size)
  validation_ds = validation_ds.take(50)

  n_steps = 0
  avg_loss=0

  # A first round of the validation loss.
  n_steps_eval = 0
  eval_loss = 0
  val_iterator = validation_ds.as_numpy_iterator()
  for val_example in val_iterator:
    eval_loss += compiled_validation_step(model,
                                          params,
                                          dataset_builder._tokenizer.pad_id,
                                          val_example)
    n_steps_eval += 1
  print(f"Start, validation loss: {eval_loss/n_steps_eval}")

  for train_example in train_ds:
    train_loss, params, opt_state = compiled_train_step(model=model,
                                                        params=params,
                                                        optimizer=optimizer,
                                                        opt_state=opt_state,
                                                        pad_id=dataset_builder._tokenizer.pad_id,
                                                        example=train_example)
    n_steps += 1
    avg_loss += train_loss
    if n_steps % training_cfg.eval_every_n == 0:
      eval_loss = 0

      n_steps_eval = 0
      val_iterator = validation_ds.as_numpy_iterator()
      for val_example in val_iterator:
        eval_loss += compiled_validation_step(model,
                                              params,
                                              dataset_builder._tokenizer.pad_id,
                                              val_example)
        n_steps_eval +=1
      avg_loss /= training_cfg.eval_every_n
      eval_loss /= n_steps_eval
      print(f"STEP {n_steps} training loss: {avg_loss} - eval loss: {eval_loss}")
      avg_loss=0
    if training_cfg.max_steps is not None and n_steps > training_cfg.max_steps:
      break
  return params

Start fine-tuning the Gemma model on a limited number of steps (SEQ_SIZE) to make sure it fits in memory:

SEQ_SIZE = 25
tokenizer = GemmaTokenizer(vocab)
dataset_builder= MTNTDatasetBuilder(tokenizer, SEQ_SIZE)
training_cfg = TrainingConfig(learning_rate=1e-4,
                              num_epochs=1,
                              eval_every_n=20,
                              batch_size=1,
                              max_steps=100)

params = train_loop(model=model_2b,
                    params={'params': params['transformer']},
                    dataset_builder=dataset_builder,
                    training_cfg=training_cfg)
Start, validation loss: 10.647212982177734
STEP 20 training loss: 3.3015992641448975 - eval loss: 2.686880111694336
STEP 40 training loss: 5.375057220458984 - eval loss: 2.6751961708068848
STEP 60 training loss: 2.6599338054656982 - eval loss: 2.663877010345459
STEP 80 training loss: 4.822389125823975 - eval loss: 2.3333375453948975
STEP 100 training loss: 2.0131142139434814 - eval loss: 2.360811948776245

The training loss and the validation loss should both decrease with each step count.

Create a sampler with gemma.sampler.Sampler. It uses the Gemma model checkpoint and the tokenizer.

sampler = sampler_lib.Sampler(
    transformer=model_2b,
    vocab=vocab,
    params=params['params'],
)

Use the sampler to check whether your model can perform translation. The total_generation_steps argument in gemma.sampler.Sampler is the number of steps performed when generating a response. To make sure the input matches the training format, use the prefix Translate this into French:\n and add a newline character at the end. This signals the model to begin translating.

sampler(
    ["Translate this into French:\nHello, my name is Morgane.\n"],
    total_generation_steps=100,
    ).text
["C'est Bonjour, mon nom est Morgane.C'est Bonjour, mon nom est Morgane."]

Learn more