Fine-tune Gemma using JAX and Flax


Overview

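Gemma is a family of lightweight, state-of-the-art open large language models built from the research and technology behind Google DeepMind's Gemini models. This tutorial demonstrates how to fine-tune the Gemma 2B Instruct model for an English-to-French translation task. It uses Google DeepMind's gemma library together with the following tools:

  • JAX, a high-performance numerical computing library.
  • Flax, the JAX-based neural network library.
  • Chex, a library of utilities for writing reliable JAX code.
  • Optax, the JAX-based gradient processing and optimization library.
  • The MTNT (Machine Translation of Noisy Text) dataset.

Flax is not used directly in this notebook, but Flax was used to create Gemma.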

The gemma library is written with JAX, Flax, Orbax (a JAX-based library for training utilities such as checkpointing), and SentencePiece (a tokenizer/detokenizer library).

Setup

1. Set up Kaggle access for Gemma

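To complete this tutorial, you first need to follow the instructions in Gemma setup. Those instructions show you how to do the following:

  • Get access to Gemma on kaggle.com.
  • Select a Colab runtime with sufficient resources to run the Gemma model.
  • Generate and configure a Kaggle username and API key.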

After you complete the Gemma setup, move on to the next section, where you will set environment variables for your Colab environment.

2. Set environment variables

KAGGLE_USERNAMEKAGGLE_KEY 设置环境变量。当系统提示“授予访问权限吗?”消息时,同意提供密钥访问权限。

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
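Outside Colab, google.colab.userdata is not available, so you would typically export the two variables in your shell instead. The following is a minimal sketch, assuming that setup, that fails early if either variable is missing before any download is attempted. The helper name check_kaggle_credentials is hypothetical and is not part of kagglehub.

```python
import os


def check_kaggle_credentials(env=os.environ) -> tuple[str, str]:
    """Return (username, key), or raise if either variable is missing or empty.

    Hypothetical helper for illustration; kagglehub does not provide it.
    """
    missing = [name for name in ("KAGGLE_USERNAME", "KAGGLE_KEY") if not env.get(name)]
    if missing:
        raise RuntimeError(f"Missing Kaggle credentials: {', '.join(missing)}")
    return env["KAGGLE_USERNAME"], env["KAGGLE_KEY"]


# Usage with an explicit mapping instead of the real environment.
print(check_kaggle_credentials({"KAGGLE_USERNAME": "alice", "KAGGLE_KEY": "secret"}))
```

In the notebook itself you would call check_kaggle_credentials() with no argument, so that it reads os.environ.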

3. Install the gemma library

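Free Colab hardware acceleration is currently insufficient to run this notebook. If you are using Colab Pay As You Go or Colab Pro, enable hardware acceleration by clicking Edit > Notebook settings, selecting A100 GPU, and clicking Save.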

Next, install the Google DeepMind gemma library from github.com/google-deepmind/gemma. If you get an error about "pip's dependency resolver", you can usually ignore it.

pip install -q git+https://github.com/google-deepmind/gemma.git

4. Import libraries

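This notebook uses the following libraries:

  • Flax, for neural networks.
  • Core JAX.
  • SentencePiece, for tokenization.
  • Chex, a library of utilities for writing reliable JAX code.
  • Optax, for optimization.
  • TensorFlow Datasets.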

import os
import enum
import re
import string

import chex
import jax
import jax.numpy as jnp
import optax

import tensorflow as tf
import tensorflow_datasets as tfds

from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm

Load the Gemma model

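Load the Gemma model with kagglehub.model_download, which takes three arguments:

  • handle: The model handle from Kaggle.
  • path: (Optional string) The local path.
  • force_download: (Optional boolean) Forces the model to be downloaded again.
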
GEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/gemma/flax/2b-it/2/download...
100%|██████████| 3.67G/3.67G [00:26<00:00, 147MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2

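Check the location of the model weights and the tokenizer, then set the path variables. The tokenizer file is in the main directory where you downloaded the model, while the model checkpoint is in a sub-directory. For example:

  • The tokenizer.model file is in /LOCAL/PATH/TO/gemma/flax/2b-it/2.
  • The model checkpoint is in /LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it.
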
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/tokenizer.model
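You can check the directory layout described above before loading anything. The following is a minimal sketch that assumes the layout shown in the printed paths: tokenizer.model sits at the top level, and the checkpoint is in a sub-directory named after the variant. The helper gemma_paths is hypothetical.

```python
import os


def gemma_paths(gemma_path: str, variant: str, check: bool = False) -> tuple[str, str]:
    """Build (checkpoint_dir, tokenizer_file) for a downloaded Gemma Flax model.

    Assumes the layout printed above:
      <gemma_path>/tokenizer.model
      <gemma_path>/<variant>/        (checkpoint directory)
    """
    ckpt_path = os.path.join(gemma_path, variant)
    tokenizer_path = os.path.join(gemma_path, "tokenizer.model")
    if check:
        # Fail early with a clear message instead of inside the checkpoint loader.
        if not os.path.isdir(ckpt_path):
            raise FileNotFoundError(f"Checkpoint directory not found: {ckpt_path}")
        if not os.path.isfile(tokenizer_path):
            raise FileNotFoundError(f"Tokenizer file not found: {tokenizer_path}")
    return ckpt_path, tokenizer_path


ckpt, tok = gemma_paths("/root/.cache/kagglehub/models/google/gemma/flax/2b-it/2", "2b-it")
print(ckpt)  # /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/2b-it
print(tok)   # /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/tokenizer.model
```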

Load and prepare the MTNT dataset and the Gemma tokenizer

You will use the MTNT (Machine Translation of Noisy Text) dataset, which is available from TensorFlow Datasets.

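Download the English-to-French portion of the MTNT dataset, and then sample two examples. Each sample in the dataset contains two entries:

  • src: the original English sentence.
  • dst: the corresponding French translation.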

ds = tfds.load("mtnt/en-fr", split="train")

ds = ds.take(2)
ds = ds.as_numpy_iterator()

for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()
Downloading and preparing dataset 35.08 MiB (download: 35.08 MiB, generated: 11.33 MiB, total: 46.41 MiB) to /root/tensorflow_datasets/mtnt/en-fr/1.0.0...
Dataset mtnt downloaded and prepared to /root/tensorflow_datasets/mtnt/en-fr/1.0.0. Subsequent calls will reuse this data.
Example 0:
dst: b'Le groupe de " toutes les \xc3\xa9toiles potentielles de la conf\xc3\xa9rence de l\'Est mais qui ne s\'en sortent pas dans le groupe de l\'Ouest ".'
src: b'The group of \xe2\x80\x9ceastern conference potential all stars but not making it in the West\xe2\x80\x9d group.'

Example 1:
dst: b"Kameron est-elle un peu aigrie de son manque de temps \xc3\xa0 l'\xc3\xa9cran ?"
src: b'Is Kameron a Little Salty About Her Lack of Air Time?'
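TFDS returns the text fields as UTF-8 encoded bytes. That is why accented characters print as escape sequences such as \xc3\xa9. The following small sketch prints examples in readable form. decode_example is a hypothetical helper, and it assumes every value in the example dict is UTF-8 bytes.

```python
def decode_example(example: dict[str, bytes]) -> dict[str, str]:
    """Decode each UTF-8 bytes field of an MTNT example into a Python str."""
    return {key: val.decode("utf-8") for key, val in example.items()}


# Example 1 from the output above.
example = {
    "dst": b"Kameron est-elle un peu aigrie de son manque de temps \xc3\xa0 l'\xc3\xa9cran ?",
    "src": b"Is Kameron a Little Salty About Her Lack of Air Time?",
}
for key, val in decode_example(example).items():
    print(f"{key}: {val}")
# dst: Kameron est-elle un peu aigrie de son manque de temps à l'écran ?
# src: Is Kameron a Little Salty About Her Lack of Air Time?
```

With the real iterator, you would call decode_example(example) inside the loop above instead of printing the raw bytes.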

Load the Gemma tokenizer, which is built with sentencepiece.SentencePieceProcessor:

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True

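Customize the SentencePieceProcessor for the English-to-French translation task. Because you are fine-tuning the Gemma model to map English input to French output, you need to make a few adjustments:

  • Input prefix: Add a common prefix to each input to signal the translation task. For example, you could use a prompt with the prefix Translate this into French: [INPUT_SENTENCE].

  • Translation start suffix: Add a suffix at the end of each prompt to tell the Gemma model exactly when to begin translating. A newline is enough.

  • Language model tokens: Gemma models expect a "beginning of sequence" (BOS) token at the start of each sequence, so adding an "end of sequence" (EOS) token at the end of each training example should be sufficient.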
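Together, the three adjustments above produce the layout below for a single training example. This is a pure-Python sketch. toy_encode is a hypothetical stand-in for SentencePiece that assigns one ID per character. The BOS and EOS IDs (2 and 1) are taken from the token IDs printed later in this tutorial. The sketch mirrors the GemmaTokenizer wrapper that follows, which prepends BOS to both the source and the destination.

```python
BOS_ID, EOS_ID = 2, 1  # Values observed in the tokenized examples in this tutorial.

PREFIX = "Translate this into French:\n"
SUFFIX = "\n"


def toy_encode(text: str) -> list[int]:
    """Hypothetical stand-in for SentencePieceProcessor.EncodeAsIds (one ID per character)."""
    return [1000 + ord(ch) for ch in text]


def build_sequence(src: str, dst: str) -> tuple[list[int], list[int]]:
    """Return (source_ids, destination_ids) following the tutorial's conventions."""
    # Source: BOS + prefix + sentence + newline suffix, no EOS (the translation follows).
    source_ids = [BOS_ID] + toy_encode(PREFIX + src + SUFFIX)
    # Destination: BOS + translation + EOS, so the model learns where to stop.
    destination_ids = [BOS_ID] + toy_encode(dst) + [EOS_ID]
    return source_ids, destination_ids


src_ids, dst_ids = build_sequence("Hi", "Salut")
print(len(src_ids), len(dst_ids))  # 32 7
```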

Build a custom wrapper around the SentencePieceProcessor as follows:

class GemmaTokenizer:

  def __init__(self,
               spm_processor: spm.SentencePieceProcessor):
    self._spm_processor = spm_processor

  @property
  def pad_id(self) -> int:
    """Fast access to the pad ID."""
    return self._spm_processor.pad_id()

  def tokenize(self,
               example: str | bytes,
               prefix: str = '',
               suffix: str = '',
               add_eos: bool = True) -> jax.Array:
    """
    The tokenization function.

    Args:
      example: Input string to tokenize.
      prefix:  Prefix to add to the input string.
      suffix:  Suffix to add to the input string.
      add_eos: If True, add an "end of sentence" token at the end of the output
               sequence.
    Returns:
      Tokens corresponding to the input string.
    """
    int_list = [self._spm_processor.bos_id()]
    int_list.extend(self._spm_processor.EncodeAsIds(prefix + example + suffix))
    if add_eos:
      int_list.append(self._spm_processor.eos_id())

    return jnp.array(int_list, dtype=jnp.int32)

  def tokenize_tf_op(self,
                     str_tensor: tf.Tensor,
                     prefix: str = '',
                     suffix: str = '',
                     add_eos: bool = True) -> tf.Tensor:
    """A TensorFlow operator for the tokenize function."""
    encoded = tf.numpy_function(
        self.tokenize,
        [str_tensor, prefix, suffix, add_eos],
        tf.int32)
    encoded.set_shape([None])
    return encoded

  def to_string(self, tokens: jax.Array) -> str:
    """Convert an array of tokens to a string."""
    # Decode (not encode) the token IDs back into text.
    return self._spm_processor.DecodeIds(tokens.tolist())

Try it out by instantiating your new custom GemmaTokenizer and applying it to a small sample of the MTNT dataset:

tokenizer = GemmaTokenizer(vocab)

def tokenize_source(tokenizer, example: tf.Tensor):
  return tokenizer.tokenize_tf_op(example,
                                  prefix='Translate this into French:\n',
                                  suffix='\n',
                                  add_eos=False)
def tokenize_destination(tokenizer, example: tf.Tensor):
  return tokenizer.tokenize_tf_op(example,
                                  add_eos=True)

ds = tfds.load("mtnt/en-fr",split="train")
ds = ds.take(2)
ds = ds.map(lambda x: {'src': tokenize_source(tokenizer, x['src']),
                       'dst': tokenize_destination(tokenizer, x['dst'])})
ds = ds.as_numpy_iterator()

for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()
Example 0:
src: [     2  49688    736   1280   6987 235292    108    651   2778    576
   1080 104745  11982   5736    832   8995    901    780   3547    665
    575    573   4589 235369   2778 235265    108]
dst: [     2   2025  29653    581    664  16298   1437  55563  41435   7840
    581    683 111452    581    533 235303   9776   4108   2459    679
    485 235303    479   6728    579   1806   2499    709  29653    581
    533 235303 101323  16054      1]

Example 1:
src: [     2  49688    736   1280   6987 235292    108   2437  87150    477
    476  11709 230461   8045   3636  40268    576   4252   4897 235336
    108]
dst: [     2 213606    477   1455 235290   3510    748   8268 191017   2809
    581   2032  69972    581  11495   1305    533 235303  65978   1654
      1]

Build a data loader for the entire MTNT dataset:

@chex.dataclass(frozen=True)
class TrainingInput:
  # Input tokens provided to the model.
  input_tokens: jax.Array

  # A mask that determines which tokens contribute to the target loss
  # calculation.
  target_mask: jax.Array

class DatasetSplit(enum.Enum):
  TRAIN = 'train'
  VALIDATION = 'valid'

class MTNTDatasetBuilder:
  """The dataset builder for the MTNT dataset."""

  N_ITEMS = {DatasetSplit.TRAIN: 35_692,
             DatasetSplit.VALIDATION: 811}

  BUFFER_SIZE_SHUFFLE = 10_000
  TRANSLATION_PREFIX = 'Translate this into French:\n'
  TRANSLATION_SUFFIX = '\n'

  def __init__(self,
               tokenizer : GemmaTokenizer,
               max_seq_len: int):
    """Constructor.

    Args:
      tokenizer: Gemma tokenizer to use.
      max_seq_len: size of each sequence in a given batch.
    """
    self._tokenizer = tokenizer
    self._base_data = {
        DatasetSplit.TRAIN: tfds.load("mtnt/en-fr",split="train"),
        DatasetSplit.VALIDATION: tfds.load("mtnt/en-fr",split="valid"),
    }
    self._max_seq_len = max_seq_len

  def _tokenize_source(self, example: tf.Tensor):
    """Tokenization function for the source."""
    return self._tokenizer.tokenize_tf_op(example,
                                          prefix=self.TRANSLATION_PREFIX,
                                          suffix=self.TRANSLATION_SUFFIX,
                                          add_eos=False)

  def _tokenize_destination(self, example: tf.Tensor):
    """Tokenization function for the French translation."""
    return self._tokenizer.tokenize_tf_op(example,
                                          add_eos=True)

  def _pad_up_to_max_len(self,
                         input_tensor: tf.Tensor,
                         pad_value: int | bool,
                         ) -> tf.Tensor:
    """Pad the given tensor up to sequence length of a batch."""
    seq_len = tf.shape(input_tensor)[0]
    to_pad = tf.maximum(self._max_seq_len - seq_len, 0)
    return tf.pad(input_tensor,
                  [[0, to_pad]],
                  mode='CONSTANT',
                  constant_values=pad_value,
                  )

  def _to_training_input(self,
                         src_tokens: jax.Array,
                         dst_tokens: jax.Array,
                         ) -> TrainingInput:
    """Build a training input from a tuple of source and destination tokens."""

    # The input sequence fed to the model is simply the concatenation of the
    # source and the destination.
    tokens = tf.concat([src_tokens, dst_tokens], axis=0)

    # To prevent the model from updating based on the source (input)
    # tokens, add a target mask to each input.
    q_mask = tf.zeros_like(src_tokens, dtype=tf.bool)
    a_mask = tf.ones_like(dst_tokens, dtype=tf.bool)
    mask = tf.concat([q_mask, a_mask], axis=0)

    # If the output tokens sequence is smaller than the target sequence size,
    # then pad it with pad tokens.
    tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)

    # Don't want to perform the backward pass on the pad tokens.
    mask = self._pad_up_to_max_len(mask, False)

    return TrainingInput(input_tokens=tokens, target_mask=mask)


  def get_train_dataset(self, batch_size: int, num_epochs: int):
    """Build the training dataset."""

    # Tokenize each sample.
    ds = self._base_data[DatasetSplit.TRAIN].map(lambda x : (self._tokenize_source(x['src']),
                                                             self._tokenize_destination(x['dst'])))

    # Convert the samples to training inputs.
    ds = ds.map(lambda x, y: self._to_training_input(x, y))

    # Remove the samples that are too long.
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)

    # Shuffle the dataset.
    ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)

    # Repeat if necessary.
    ds = ds.repeat(num_epochs)

    # Build batches.
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

  def get_validation_dataset(self, batch_size: int):
    """Build the validation dataset."""

    # Same steps as in `get_train_dataset`, but without shuffling and no repetition.
    ds = self._base_data[DatasetSplit.VALIDATION].map(lambda x : (self._tokenize_source(x['src']),
                                                                  self._tokenize_destination(x['dst'])))
    ds = ds.map(lambda x, y: self._to_training_input(x, y))
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

To try out MTNTDatasetBuilder, instantiate the custom GemmaTokenizer again, apply the builder to the MTNT dataset, and sample two batches of three examples each:

tokenizer = GemmaTokenizer(vocab)

dataset_builder = MTNTDatasetBuilder(tokenizer, max_seq_len=20)
ds = dataset_builder.get_train_dataset(3, 1)
ds = ds.take(2)
ds = ds.as_numpy_iterator()

for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'>
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'>
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'>
Example 0:
input_tokens: [[     2  49688    736   1280   6987 235292    108  10924    665  12302
  235341    108      2   4397  63011   1437  38696   1241      1      0]
 [     2  49688    736   1280   6987 235292    108  13835   1517 235265
     108      2  69875    540  19713 235265      1      0      0      0]
 [     2  49688    736   1280   6987 235292    108   6956   1586 235297
  235265    108      2  78368   1586 235297 235265      1      0      0]]
target_mask: [[False False False False False False False False False False False False
   True  True  True  True  True  True  True False]
 [False False False False False False False False False False False  True
   True  True  True  True  True False False False]
 [False False False False False False False False False False False False
   True  True  True  True  True  True False False]]

Example 1:
input_tokens: [[     2  49688    736   1280   6987 235292    108  18874 235341    108
       2 115905   6425   1241      1      0      0      0      0      0]
 [     2  49688    736   1280   6987 235292    108   7574   3356 235341
     108      2   7997  20707   1241      1      0      0      0      0]
 [     2  49688    736   1280   6987 235292    108   8703    665 235265
     108      2 235338 235303  90006  20133 235265      1      0      0]]
target_mask: [[False False False False False False False False False False  True  True
   True  True  True False False False False False]
 [False False False False False False False False False False False  True
   True  True  True  True False False False False]
 [False False False False False False False False False False False  True
   True  True  True  True  True  True False False]]
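To see exactly what _to_training_input produces, you can rebuild the first row of Example 0 above with NumPy. The following sketch mirrors the TensorFlow code: concatenate, build the mask, then pad. It uses the source and destination token IDs read off that row and assumes pad_id is 0, as the trailing zeros suggest. to_training_input_np is a hypothetical name.

```python
import numpy as np


def to_training_input_np(src_tokens, dst_tokens, max_seq_len: int, pad_id: int = 0):
    """NumPy mirror of MTNTDatasetBuilder._to_training_input for one example."""
    src = np.asarray(src_tokens, dtype=np.int32)
    dst = np.asarray(dst_tokens, dtype=np.int32)
    tokens = np.concatenate([src, dst])
    # The loss is computed only on the destination (French) tokens.
    mask = np.concatenate([np.zeros_like(src, dtype=bool), np.ones_like(dst, dtype=bool)])
    # Right-pad both arrays; sequences longer than max_seq_len are left as-is
    # (the builder's `filter` step drops them afterwards).
    to_pad = max(max_seq_len - len(tokens), 0)
    tokens = np.pad(tokens, (0, to_pad), constant_values=pad_id)
    mask = np.pad(mask, (0, to_pad), constant_values=False)
    return tokens, mask


src = [2, 49688, 736, 1280, 6987, 235292, 108, 10924, 665, 12302, 235341, 108]
dst = [2, 4397, 63011, 1437, 38696, 1241, 1]
tokens, mask = to_training_input_np(src, dst, max_seq_len=20)
print(tokens)
print(mask.astype(int))
```

The mask switches on at the second BOS token (2) and off after the EOS token (1), so padding and the English prompt never contribute to the loss.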

Configure the model

Before you begin fine-tuning the Gemma model, you need to configure it.

First, load and format the Gemma model checkpoint with the gemma.params.load_and_format_params method:

params = params_lib.load_and_format_params(CKPT_PATH)

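To load the correct configuration automatically from the Gemma model checkpoint, use gemma.transformer.TransformerConfig.from_params. The cache_size argument is the number of time steps in the Gemma Transformer cache. Then, instantiate the Gemma model as model_2b with gemma.transformer.Transformer, which inherits from flax.linen.Module.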

config_2b = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=30
)

model_2b = transformer_lib.Transformer(config=config_2b)

Fine-tune the model

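In this section, you will:

  • Use the gemma.transformer.Transformer class to create the forward pass and loss function.
  • Build the position and attention mask vectors for the tokens.
  • Build a training step function with JAX and Optax.
  • Build the validation step, which has no backward pass.
  • Create the training loop.
  • Fine-tune the Gemma model.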

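Define the forward pass and the loss function using the gemma.transformer.Transformer class. The Gemma Transformer inherits from flax.linen.Module and offers two essential methods:

  • init: Initializes the model's parameters.
  • apply: Executes the model's __call__ function using a given set of parameters.

Since you are working with pretrained Gemma weights, you don't need to use the init function.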

def forward_and_loss_fn(params,
                        *,
                        model: transformer_lib.Transformer,
                        input_tokens: jax.Array,            # Shape [B, L]
                        input_mask: jax.Array,              # Shape [B, L]
                        positions: jax.Array,               # Shape [B, L]
                        attention_mask: jax.Array,          # [B, L, L]
                        ) -> jax.Array:
  """The forward pass and the loss function.

  Args:
    params: Model's input parameters.
    model: The Gemma transformer model to call.
    input_tokens: Input tokens sequence, shape [B, L].
    input_mask: Tokens to ignore when computing the loss, shape [B, L].
    positions: Relative position of each token, shape [B, L].
    attention_mask: Input attention mask, shape [B, L, L].

  Returns:
    The softmax cross-entropy loss for the next-token prediction task.
  """

  # The forward pass on the input data.
  # No attention cache is needed here.
  logits, _ = model.apply(
        params,
        input_tokens,
        positions,
        None,              # Attention cache is None.
        attention_mask,
    )

  # Exclude the last step as it does not appear in the targets.
  # Slice along the sequence axis only, so every example in the batch is used.
  logits = logits[:, :-1]

  # Similarly, the first token cannot be predicted.
  target_tokens = input_tokens[:, 1:]
  target_mask = input_mask[:, 1:]

  # Convert the target labels to one-hot encoded vectors.
  one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])

  # Don't update on unwanted tokens.
  one_hot = one_hot * target_mask.astype(one_hot.dtype)[...,None]

  # Define the normalization factor.
  norm_factor = 1 / (jnp.sum(target_mask) + 1e-8)

  # Return the negative log likelihood (NLL) loss.
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) * norm_factor
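The loss is a masked next-token cross-entropy: position t predicts token t+1, and only destination tokens count. The following NumPy sketch performs the same computation over a whole batch and is useful for checking shapes. It assumes logits of shape [B, L, V] and is a re-implementation for illustration, not the tutorial's JAX code. With all-zero logits, every token has probability 1/V, so the loss should equal log V.

```python
import numpy as np


def log_softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last axis."""
    x = x - x.max(axis=-1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))


def masked_next_token_nll(logits, input_tokens, input_mask):
    """logits: [B, L, V]; input_tokens and input_mask: [B, L]."""
    logits = logits[:, :-1]                     # The last position has no next token.
    targets = input_tokens[:, 1:]               # The first token is never predicted.
    target_mask = input_mask[:, 1:].astype(logits.dtype)
    one_hot = np.eye(logits.shape[-1], dtype=logits.dtype)[targets]  # [B, L-1, V]
    one_hot = one_hot * target_mask[..., None]  # Zero out prompt and pad positions.
    norm = 1.0 / (target_mask.sum() + 1e-8)     # Average over unmasked targets only.
    return -(log_softmax(logits) * one_hot).sum() * norm


B, L, V = 2, 5, 7
tokens = np.array([[2, 3, 4, 1, 0], [2, 5, 6, 4, 1]])
mask = np.array([[0, 0, 1, 1, 0], [0, 1, 1, 1, 1]], dtype=bool)
loss = masked_next_token_nll(np.zeros((B, L, V), dtype=np.float32), tokens, mask)
print(loss, np.log(V))
```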

The gemma.transformer.Transformer class requires an attention_mask and a positions vector alongside each input. You can generate both with a custom function that uses the gemma.transformer helpers build_positions_from_mask and make_causal_attn_mask:

def get_attention_mask_and_positions(example: jax.Array,
                                     pad_id : int,
                                     )-> tuple[jax.Array, jax.Array]:
  """Builds the position and attention mask vectors from the given tokens."""
  pad_mask = example != pad_id
  current_token_position = transformer_lib.build_positions_from_mask(pad_mask)
  attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)
  return current_token_position, attention_mask
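For intuition, here is a NumPy sketch of what the two helpers are expected to compute for a right-padded row. It reflects our understanding of their behavior rather than a documented contract, so check the gemma source for the exact definitions. Positions count the non-pad tokens, and pad slots repeat the last position. Each query may attend only to non-pad keys at or before its own index.

```python
import numpy as np


def positions_from_mask(pad_mask: np.ndarray) -> np.ndarray:
    """[B, L] bool -> [B, L] int: 0-based index among non-pad tokens."""
    positions = np.cumsum(pad_mask, axis=-1)
    # Subtract 1 wherever at least one real token has been seen.
    return positions - (positions >= 1)


def causal_attn_mask(pad_mask: np.ndarray) -> np.ndarray:
    """[B, L] bool -> [B, L, L] bool: query i may attend to key j iff j <= i and j is not padding."""
    seq_len = pad_mask.shape[-1]
    causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    return pad_mask[:, None, :] & causal[None, :, :]


pad_id = 0
example = np.array([[2, 17, 42, 1, 0, 0]])
pad_mask = example != pad_id
print(positions_from_mask(pad_mask))  # [[0 1 2 3 3 3]]
print(causal_attn_mask(pad_mask)[0].astype(int))
```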

Build the train_step function, which performs the backward pass and updates the model's parameters accordingly:

def train_step(model: transformer_lib.Transformer,
               params,
               optimizer: optax.GradientTransformation,
               opt_state: optax.OptState,
               pad_id: int,
               example: TrainingInput):
  """Train step.

  Args:
    model: The Gemma transformer model.
    params: The model's input parameters.
    optimizer: The Optax optimizer to use.
    opt_state: The input optimizer's state.
    pad_id: ID of the pad token.
    example: Input batch.

  Returns:
    The training loss, the updated parameters, and the updated optimizer state.
  """

  # Build the position and attention mask vectors.
  positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)

  # The forward and backward passes.
  train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(params,
                                                             model=model,
                                                             input_tokens=example.input_tokens,
                                                             input_mask=example.target_mask,
                                                             positions=positions,
                                                             attention_mask=attention_mask)
  # Update the parameters.
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)

  return train_loss, params, opt_state
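train_step follows the usual three-part pattern:

  • Compute the loss and its gradient in one call.
  • Turn the gradient into updates.
  • Add the updates to the parameters.

For plain SGD, which the training loop gets from optax.sgd, the update is simply -learning_rate * grad. The following NumPy sketch shows that pattern on a toy least-squares problem. Here the gradient is written by hand instead of being obtained from jax.value_and_grad. All names in the sketch are hypothetical.

```python
import numpy as np


def value_and_grad_mse(params: dict, x: np.ndarray, y: np.ndarray):
    """Loss and hand-written gradient of mean squared error for y ~ x @ w + b."""
    err = x @ params["w"] + params["b"] - y
    loss = np.mean(err ** 2)
    grads = {"w": 2 * x.T @ err / len(y), "b": 2 * np.mean(err)}
    return loss, grads


def sgd_train_step(params: dict, x, y, learning_rate: float):
    """One step: loss and grads, then params <- params + updates (plain SGD)."""
    loss, grads = value_and_grad_mse(params, x, y)
    updates = {k: -learning_rate * g for k, g in grads.items()}
    params = {k: params[k] + updates[k] for k in params}
    return loss, params


rng = np.random.default_rng(0)
x = rng.normal(size=(64, 3))
y = x @ np.array([1.0, -2.0, 0.5]) + 0.3
params = {"w": np.zeros(3), "b": 0.0}
losses = []
for _ in range(200):
    loss, params = sgd_train_step(params, x, y, learning_rate=0.1)
    losses.append(loss)
print(losses[0], losses[-1])
```

The update function returns new parameters instead of mutating them, which matches how train_step returns params and opt_state.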

Build the validation_step function, which runs only the forward pass:

def validation_step(model: transformer_lib.Transformer,
                    params,
                    pad_id: int,
                    example: TrainingInput,
                    ):
  positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)
  val_loss = forward_and_loss_fn(params,
                                 model=model,
                                 input_tokens=example.input_tokens,
                                 input_mask=example.target_mask,
                                 positions=positions,
                                 attention_mask=attention_mask)
  return val_loss

Define the training loop, using optax.sgd as the SGD optimizer:

@chex.dataclass(frozen=True)
class TrainingConfig:
  learning_rate: float
  num_epochs: int
  eval_every_n: int
  batch_size: int
  max_steps: int | None = None

def train_loop(
    model: transformer_lib.Transformer,
    params,
    dataset_builder: MTNTDatasetBuilder,
    training_cfg: TrainingConfig):

  # Apply `jax.jit` on the training step, making the whole loop much more efficient.
  compiled_train_step = jax.jit(train_step, static_argnames=['model', 'optimizer'])

  # Apply `jax.jit` on the validation step.
  compiled_validation_step = jax.jit(validation_step, static_argnames=['model'])

  # To save memory, use the SGD optimizer instead of the usual Adam optimizer.
  # Note that for this specific example, SGD is more than enough.
  optimizer = optax.sgd(training_cfg.learning_rate)
  opt_state = optimizer.init(params)

  # Build the training dataset.
  train_ds = dataset_builder.get_train_dataset(batch_size=training_cfg.batch_size,
                                               num_epochs=training_cfg.num_epochs)
  train_ds = train_ds.as_numpy_iterator()

  # Build the validation dataset, with a limited number of samples for this demo.
  validation_ds = dataset_builder.get_validation_dataset(batch_size=training_cfg.batch_size)
  validation_ds = validation_ds.take(50)

  n_steps = 0
  avg_loss=0

  # A first round of the validation loss.
  n_steps_eval = 0
  eval_loss = 0
  val_iterator = validation_ds.as_numpy_iterator()
  for val_example in val_iterator:
    eval_loss += compiled_validation_step(model,
                                          params,
                                          dataset_builder._tokenizer.pad_id,
                                          val_example)
    n_steps_eval += 1
  print(f"Start, validation loss: {eval_loss/n_steps_eval}")

  for train_example in train_ds:
    train_loss, params, opt_state = compiled_train_step(model=model,
                                                        params=params,
                                                        optimizer=optimizer,
                                                        opt_state=opt_state,
                                                        pad_id=dataset_builder._tokenizer.pad_id,
                                                        example=train_example)
    n_steps += 1
    avg_loss += train_loss
    if n_steps % training_cfg.eval_every_n == 0:
      eval_loss = 0

      n_steps_eval = 0
      val_iterator = validation_ds.as_numpy_iterator()
      for val_example in val_iterator:
        eval_loss += compiled_validation_step(model,
                                              params,
                                              dataset_builder._tokenizer.pad_id,
                                              val_example)
        n_steps_eval +=1
      avg_loss /= training_cfg.eval_every_n
      eval_loss /= n_steps_eval
      print(f"STEP {n_steps} training loss: {avg_loss} - eval loss: {eval_loss}")
      avg_loss=0
    if training_cfg.max_steps is not None and n_steps > training_cfg.max_steps:
      break
  return params
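You can check the loop's logging logic in isolation. The following pure-Python sketch reproduces the same bookkeeping with fake per-step losses:

  • Accumulate the training loss at every step.
  • Every eval_every_n steps, report the mean and reset the accumulator.
  • Stop on the same max_steps condition as train_loop.

The helper simulate_logging is hypothetical.

```python
def simulate_logging(step_losses, eval_every_n: int, max_steps=None):
    """Mirror train_loop's bookkeeping; return (logged, steps_run).

    `logged` is a list of (step, mean training loss over the last eval_every_n steps).
    """
    logged = []
    n_steps, avg_loss = 0, 0.0
    for loss in step_losses:
        n_steps += 1
        avg_loss += loss
        if n_steps % eval_every_n == 0:
            logged.append((n_steps, avg_loss / eval_every_n))
            avg_loss = 0.0
        # Same stopping condition as train_loop.
        if max_steps is not None and n_steps > max_steps:
            break
    return logged, n_steps


logged, steps_run = simulate_logging([1.0] * 500, eval_every_n=20, max_steps=100)
print([step for step, _ in logged], steps_run)  # [20, 40, 60, 80, 100] 101
```

Because the check is n_steps > max_steps, the loop runs one step past max_steps: 101 steps for max_steps=100, and the last step is never logged.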

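Begin fine-tuning the Gemma model. To make sure training fits in memory, this run does the following:

  • Keeps the maximum sequence length (SEQ_SIZE) short.
  • Uses a batch size of 1.
  • Caps training at a limited number of steps (max_steps).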

SEQ_SIZE = 25
tokenizer = GemmaTokenizer(vocab)
dataset_builder= MTNTDatasetBuilder(tokenizer, SEQ_SIZE)
training_cfg = TrainingConfig(learning_rate=1e-4,
                              num_epochs=1,
                              eval_every_n=20,
                              batch_size=1,
                              max_steps=100)

params = train_loop(model=model_2b,
                    params={'params': params['transformer']},
                    dataset_builder=dataset_builder,
                    training_cfg=training_cfg)
Start, validation loss: 10.647212982177734
STEP 20 training loss: 3.3015992641448975 - eval loss: 2.686880111694336
STEP 40 training loss: 5.375057220458984 - eval loss: 2.6751961708068848
STEP 60 training loss: 2.6599338054656982 - eval loss: 2.663877010345459
STEP 80 training loss: 4.822389125823975 - eval loss: 2.3333375453948975
STEP 100 training loss: 2.0131142139434814 - eval loss: 2.360811948776245

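Both the training loss and the validation loss should trend downward as the number of steps increases. With a batch size of 1, the training loss reported at each interval is noisy, as the log above shows, so the validation loss is the better indicator of progress.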

Create a sampler with gemma.sampler.Sampler. It takes the Gemma model, the fine-tuned parameters, and the tokenizer vocabulary.

sampler = sampler_lib.Sampler(
    transformer=model_2b,
    vocab=vocab,
    params=params['params'],
)

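Use the sampler to check whether the model can translate. The total_generation_steps argument, passed when calling the sampler, is the number of steps performed when generating a response. To match the training format, start the input with the prefix Translate this into French:\n and end it with a newline character, which signals the model to begin the translation.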

sampler(
    ["Translate this into French:\nHello, my name is Morgane.\n"],
    total_generation_steps=100,
    ).text
["C'est Bonjour, mon nom est Morgane.C'est Bonjour, mon nom est Morgane."]

Learn more