Tinh chỉnh RecurrentGemma bằng JAX và Flax

Xem trên ai.google.dev Chạy trong Google Colab Mở trong Vertex AI Xem nguồn trên GitHub

Hướng dẫn này minh hoạ cách tinh chỉnh mô hình Hướng dẫn RecurrentGemma 2B cho nhiệm vụ dịch tiếng Anh-Pháp bằng thư viện recurrentgemma của Google DeepMind, JAX (thư viện điện toán số hiệu suất cao), Flax (thư viện mạng nơron dựa trên JAX), Chex (thư viện tiện ích để viết mã JAX và xử lý JAX không đáng tin cậy Mặc dù Flax không được sử dụng trực tiếp trong sổ tay này, nhưng Flax đã được dùng để tạo Gemma.

Thư viện recurrentgemma được viết bằng JAX, Flax, Orbax (một thư viện dựa trên JAX để huấn luyện các tiện ích như kiểm tra điểm kiểm tra) và SentencePiece (thư viện tokenizer/detokenizer).

Sổ tay này có thể chạy trên Google Colab với GPU T4 (chuyển đến phần Chỉnh sửa > Cài đặt sổ tay > Trong phần Trình tăng tốc phần cứng, hãy chọn GPU T4).

Thiết lập

Các phần sau đây giải thích các bước chuẩn bị để sử dụng mô hình RecurrentGemma cho sổ tay, bao gồm cả quyền truy cập vào mô hình, lấy khoá API và định cấu hình thời gian chạy của sổ tay.

Thiết lập quyền truy cập vào Kaggle cho Gemma

Để hoàn tất hướng dẫn này, trước tiên, bạn cần làm theo các hướng dẫn thiết lập tương tự như thiết lập Gemma với một số ngoại lệ:

  • Truy cập vào RecurrentGemma (thay vì Gemma) trên kaggle.com.
  • Chọn một môi trường thời gian chạy Colab có đủ tài nguyên để chạy mô hình RecurrentGemma.
  • Tạo và định cấu hình tên người dùng Kaggle và khoá API.

Sau khi hoàn tất quy trình thiết lập RecurrentGemma, hãy chuyển sang phần tiếp theo để thiết lập các biến môi trường cho môi trường Colab của bạn.

Đặt các biến môi trường

Thiết lập các biến môi trường cho KAGGLE_USERNAMEKAGGLE_KEY. Khi được nhắc với thông báo "Cấp quyền truy cập?", hãy đồng ý cung cấp quyền truy cập bí mật.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Cài đặt thư viện recurrentgemma

Tính năng tăng tốc phần cứng miễn phí của Colab hiện không đủ để chạy sổ tay này. Nếu bạn đang sử dụng Colab Pay As You Go hoặc Colab Pro, hãy nhấp vào Chỉnh sửa > Cài đặt sổ tay > Chọn GPU A100 > Lưu để bật tính năng tăng tốc phần cứng.

Tiếp theo, bạn cần cài đặt thư viện Google DeepMind recurrentgemma từ github.com/google-deepmind/recurrentgemma. Nếu gặp lỗi "trình phân giải phần phụ thuộc của pip", bạn thường có thể bỏ qua lỗi đó.

pip install -q git+https://github.com/google-deepmind/recurrentgemma.git

Nhập thư viện

Sổ tay này sử dụng Flax (dành cho mạng nơron), lõi JAX, SentencePiece (để mã hoá), Chex (thư viện tiện ích để viết mã JAX đáng tin cậy), Optax (thư viện xử lý và tối ưu hoá độ dốc) và Tập dữ liệu TensorFlow.

import pathlib
from typing import Any, Mapping, Iterator
import enum
import functools

import chex
import jax
import jax.numpy as jnp
import optax

import tensorflow as tf
import tensorflow_datasets as tfds

import sentencepiece as spm

from recurrentgemma import jax as recurrentgemma

Tải mô hình RecurrentGemma

  1. Tải mô hình RecurrentGemma bằng kagglehub.model_download. Thao tác này sẽ nhận 3 đối số:
  • handle: Tên người dùng mô hình trong Kaggle
  • path: (Chuỗi không bắt buộc) Đường dẫn cục bộ
  • force_download: (Boolean không bắt buộc) Buộc tải lại mô hình xuống
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download...
100%|██████████| 3.85G/3.85G [00:50<00:00, 81.5MB/s]
Extracting model files...
print('RECURRENTGEMMA_VARIANT:', RECURRENTGEMMA_VARIANT)
RECURRENTGEMMA_VARIANT: 2b-it
  1. Kiểm tra vị trí của trọng số mô hình và trình tạo mã thông báo, sau đó đặt các biến đường dẫn. Thư mục trình tạo mã thông báo sẽ nằm trong thư mục chính mà bạn đã tải mô hình xuống, còn trọng số của mô hình sẽ nằm trong thư mục con. Ví dụ:
  • Tệp tokenizer.model sẽ nằm trong /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1).
  • Điểm kiểm tra mô hình sẽ nằm trong /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it).
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model

Tải và chuẩn bị tập dữ liệu MTNT và trình tạo mã thông báo Gemma

Bạn sẽ dùng tập dữ liệu MTNT (Bản dịch máy của văn bản nhiễu) có sẵn trong Tập dữ liệu TensorFlow.

Tải xuống phần tập dữ liệu từ tiếng Anh sang tiếng Pháp của tập dữ liệu MTNT, sau đó lấy mẫu hai ví dụ. Mỗi mẫu trong tập dữ liệu chứa hai mục nhập: src: câu tiếng Anh gốc; và dst: bản dịch tiếng Pháp tương ứng.

ds = tfds.load("mtnt/en-fr", split="train")

ds = ds.take(2)
ds = ds.as_numpy_iterator()

for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()
Downloading and preparing dataset 35.08 MiB (download: 35.08 MiB, generated: 11.33 MiB, total: 46.41 MiB) to /root/tensorflow_datasets/mtnt/en-fr/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Extraction completed...: 0 file [00:00, ? file/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/35692 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incompleteJLH33K/mtnt-train.tfrecord*...:   0%|          …
Generating test examples...:   0%|          | 0/1020 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incompleteJLH33K/mtnt-test.tfrecord*...:   0%|          |…
Generating valid examples...:   0%|          | 0/811 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incompleteJLH33K/mtnt-valid.tfrecord*...:   0%|          …
Dataset mtnt downloaded and prepared to /root/tensorflow_datasets/mtnt/en-fr/1.0.0. Subsequent calls will reuse this data.
Example 0:
dst: b'Le groupe de " toutes les \xc3\xa9toiles potentielles de la conf\xc3\xa9rence de l\'Est mais qui ne s\'en sortent pas dans le groupe de l\'Ouest ".'
src: b'The group of \xe2\x80\x9ceastern conference potential all stars but not making it in the West\xe2\x80\x9d group.'

Example 1:
dst: b"Kameron est-elle un peu aigrie de son manque de temps \xc3\xa0 l'\xc3\xa9cran ?"
src: b'Is Kameron a Little Salty About Her Lack of Air Time?'

Tải trình tạo mã thông báo Gemma, được tạo bằng sentencepiece.SentencePieceProcessor:

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True

Tuỳ chỉnh SentencePieceProcessor cho nhiệm vụ dịch từ tiếng Anh sang tiếng Pháp. Do bạn sẽ tinh chỉnh phần tiếng Anh của mô hình RecurrentGemma (Griffin), bạn cần thực hiện một vài điều chỉnh như:

  • Tiền tố dữ liệu đầu vào: Việc thêm một tiền tố chung vào mỗi dữ liệu đầu vào sẽ báo hiệu cho nhiệm vụ dịch. Ví dụ: bạn có thể sử dụng câu lệnh có tiền tố như Translate this into French: [INPUT_SENTENCE].

  • Hậu tố bắt đầu dịch: Việc thêm hậu tố vào cuối mỗi câu lệnh sẽ hướng dẫn mô hình Gemma chính xác thời điểm bắt đầu quá trình dịch. Một dòng mới sẽ thực hiện công việc.

  • Mã thông báo mô hình ngôn ngữ: Các mô hình RecurrentGemma (Griffin) yêu cầu có mã thông báo "bắt đầu trình tự" ở đầu mỗi trình tự. Tương tự, bạn cần thêm mã thông báo "kết thúc trình tự" ở cuối mỗi ví dụ huấn luyện.

Hãy tạo một trình bao bọc tuỳ chỉnh xung quanh SentencePieceProcessor như sau:

class GriffinTokenizer:
  """A custom wrapper around a SentencePieceProcessor."""

  def __init__(self, spm_processor: spm.SentencePieceProcessor):
    self._spm_processor = spm_processor

  @property
  def pad_id(self) -> int:
    """Fast access to the pad ID."""
    return self._spm_processor.pad_id()

  def tokenize(
      self,
      example: str | bytes,
      prefix: str = '',
      suffix: str = '',
      add_eos: bool = True,
  ) -> jax.Array:
    """
    A tokenization function.

    Args:
      example: Input string to tokenize.
      prefix:  Prefix to add to the input string.
      suffix:  Suffix to add to the input string.
      add_eos: If True, add an end of sentence token at the end of the output
               sequence.
    Returns:
      Tokens corresponding to the input string.
    """
    int_list = [self._spm_processor.bos_id()]
    int_list.extend(self._spm_processor.EncodeAsIds(prefix + example + suffix))
    if add_eos:
      int_list.append(self._spm_processor.eos_id())

    return jnp.array(int_list, dtype=jnp.int32)

  def tokenize_tf_op(
      self,
      str_tensor: tf.Tensor,
      prefix: str = '',
      suffix: str = '',
      add_eos: bool = True,
  ) -> tf.Tensor:
    """A TensforFlow operator for the `tokenize` function."""
    encoded = tf.numpy_function(
        self.tokenize,
        [str_tensor, prefix, suffix, add_eos],
        tf.int32)
    encoded.set_shape([None])
    return encoded

  def to_string(self, tokens: jax.Array) -> str:
    """Convert an array of tokens to a string."""
    return self._spm_processor.EncodeIds(tokens.tolist())

Hãy dùng thử bằng cách tạo thực thể cho GriffinTokenizer tuỳ chỉnh mới của bạn, sau đó áp dụng phương thức này trên một mẫu nhỏ của tập dữ liệu MTNT:

def tokenize_source(tokenizer, example: tf.Tensor):
  return tokenizer.tokenize_tf_op(
      example,
      prefix='Translate this into French:\n',
      suffix='\n',
      add_eos=False
  )
def tokenize_destination(tokenizer, example: tf.Tensor):
  return tokenizer.tokenize_tf_op(example, add_eos=True)

tokenizer = GriffinTokenizer(vocab)

ds = tfds.load("mtnt/en-fr",split="train")
ds = ds.take(2)
ds = ds.map(lambda x: {
    'src': tokenize_source(tokenizer, x['src']),
    'dst': tokenize_destination(tokenizer, x['dst'])
  })
ds = ds.as_numpy_iterator()

for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()
Example 0:
src: [     2  49688    736   1280   6987 235292    108    651   2778    576
   1080 104745  11982   5736    832   8995    901    780   3547    665
    575    573   4589 235369   2778 235265    108]
dst: [     2   2025  29653    581    664  16298   1437  55563  41435   7840
    581    683 111452    581    533 235303   9776   4108   2459    679
    485 235303    479   6728    579   1806   2499    709  29653    581
    533 235303 101323  16054      1]

Example 1:
src: [     2  49688    736   1280   6987 235292    108   2437  87150    477
    476  11709 230461   8045   3636  40268    576   4252   4897 235336
    108]
dst: [     2 213606    477   1455 235290   3510    748   8268 191017   2809
    581   2032  69972    581  11495   1305    533 235303  65978   1654
      1]

Tạo một trình tải dữ liệu cho toàn bộ tập dữ liệu MTNT:

@chex.dataclass(frozen=True)
class TrainingInput:
  # Input tokens provided to the model.
  input_tokens: jax.Array

  # A mask that determines which tokens contribute to the target loss
  # calculation.
  target_mask: jax.Array

class DatasetSplit(enum.Enum):
  TRAIN = 'train'
  VALIDATION = 'valid'


class MTNTDatasetBuilder:
  """A data loader for the MTNT dataset."""

  N_ITEMS = {DatasetSplit.TRAIN: 35_692, DatasetSplit.VALIDATION: 811}

  BUFFER_SIZE_SHUFFLE = 10_000
  TRANSLATION_PREFIX = 'Translate this into French:\n'
  TRANSLATION_SUFFIX = '\n'

  def __init__(self,
               tokenizer : GriffinTokenizer,
               max_seq_len: int):
    """A constructor.

    Args:
      tokenizer: The tokenizer to use.
      max_seq_len: The size of each sequence in a given batch.
    """
    self._tokenizer = tokenizer
    self._base_data = {
        DatasetSplit.TRAIN: tfds.load("mtnt/en-fr",split="train"),
        DatasetSplit.VALIDATION: tfds.load("mtnt/en-fr",split="valid"),
    }
    self._max_seq_len = max_seq_len

  def _tokenize_source(self, example: tf.Tensor):
    """A tokenization function for the source."""
    return self._tokenizer.tokenize_tf_op(
        example, prefix=self.TRANSLATION_PREFIX, suffix=self.TRANSLATION_SUFFIX,
        add_eos=False
    )

  def _tokenize_destination(self, example: tf.Tensor):
    """A tokenization function for the French translation."""
    return self._tokenizer.tokenize_tf_op(example, add_eos=True)

  def _pad_up_to_max_len(self,
                         input_tensor: tf.Tensor,
                         pad_value: int | bool,
                         ) -> tf.Tensor:
    """Pad the given tensor up to sequence length of a batch."""
    seq_len = tf.shape(input_tensor)[0]
    to_pad = tf.maximum(self._max_seq_len - seq_len, 0)
    return tf.pad(
        input_tensor, [[0, to_pad]], mode='CONSTANT', constant_values=pad_value,
    )

  def _to_training_input(
      self,
      src_tokens: jax.Array,
      dst_tokens: jax.Array,
  ) -> TrainingInput:
    """Build a training input from a tuple of source and destination tokens."""

    # The input sequence fed to the model is simply the concatenation of the
    # source and the destination.
    tokens = tf.concat([src_tokens, dst_tokens], axis=0)

    # You want to prevent the model from updating based on the source (input)
    # tokens. To achieve this, add a target mask to each input.
    q_mask = tf.zeros_like(src_tokens, dtype=tf.bool)
    a_mask = tf.ones_like(dst_tokens, dtype=tf.bool)
    mask = tf.concat([q_mask, a_mask], axis=0)

    # If the output tokens sequence is smaller than the target sequence size,
    # then pad it with pad tokens.
    tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)

    # You don't want to perform the backward on the pad tokens.
    mask = self._pad_up_to_max_len(mask, False)

    return TrainingInput(input_tokens=tokens, target_mask=mask)


  def get_train_dataset(self, batch_size: int, num_epochs: int):
    """Build the training dataset."""

    # Tokenize each sample.
    ds = self._base_data[DatasetSplit.TRAIN].map(
        lambda x : (self._tokenize_source(x['src']),
                    self._tokenize_destination(x['dst']))
    )

    # Convert them to training inputs.
    ds = ds.map(lambda x, y: self._to_training_input(x, y))

    # Remove the samples which are too long.
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)

    # Shuffle the dataset.
    ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)

    # Repeat if necessary.
    ds = ds.repeat(num_epochs)

    # Build batches.
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

  def get_validation_dataset(self, batch_size: int):
    """Build the validation dataset."""

    # Same as the training dataset, but no shuffling and no repetition
    ds = self._base_data[DatasetSplit.VALIDATION].map(
        lambda x : (self._tokenize_source(x['src']),
                    self._tokenize_destination(x['dst']))
    )
    ds = ds.map(lambda x, y: self._to_training_input(x, y))
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

Hãy thử MTNTDatasetBuilder bằng cách tạo thực thể cho GriffinTokenizer tuỳ chỉnh một lần nữa, sau đó áp dụng phương thức này trên tập dữ liệu MTNT và lấy mẫu 2 ví dụ:

dataset_builder = MTNTDatasetBuilder(tokenizer, max_seq_len=20)
ds = dataset_builder.get_train_dataset(3, 1)
ds = ds.take(2)
ds = ds.as_numpy_iterator()

for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'>
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'>
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'>
Example 0:
input_tokens: [[     2  49688    736   1280   6987 235292    108  12583    665 235265
     108      2   6151  94975   1320   6238 235265      1      0      0]
 [     2  49688    736   1280   6987 235292    108   4899  29960  11270
  108282 235265    108      2   4899  79025  11270 108282      1      0]
 [     2  49688    736   1280   6987 235292    108  26620 235265    108
       2  26620 235265      1      0      0      0      0      0      0]]
target_mask: [[False False False False False False False False False False False  True
   True  True  True  True  True  True False False]
 [False False False False False False False False False False False False
  False  True  True  True  True  True  True False]
 [False False False False False False False False False False  True  True
   True  True False False False False False False]]

Example 1:
input_tokens: [[     2  49688    736   1280   6987 235292    108    527   5174   1683
  235336    108      2 206790    581  20726    482   2208   1654      1]
 [     2  49688    736   1280   6987 235292    108  28484 235256 235336
     108      2 120500  13832   1654      1      0      0      0      0]
 [     2  49688    736   1280   6987 235292    108 235324 235304   2705
  235265    108      2 235324 235304  19963 235265      1      0      0]]
target_mask: [[False False False False False False False False False False False False
   True  True  True  True  True  True  True  True]
 [False False False False False False False False False False False  True
   True  True  True  True False False False False]
 [False False False False False False False False False False False False
   True  True  True  True  True  True False False]]

Định cấu hình mô hình

Trước khi bắt đầu tinh chỉnh mô hình Gemma, bạn cần định cấu hình mô hình đó.

Tải điểm kiểm tra của mô hình RecurrentGemma (Griffin) bằng phương thức recurrentgemma.jax.utils.load_parameters:

params =  recurrentgemma.load_parameters(CKPT_PATH, "single_device")

Để tự động tải cấu hình chính xác từ điểm kiểm tra mô hình RecurrentGemma, hãy sử dụng recurrentgemma.GriffinConfig.from_flax_params_or_variables:

config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(params)

Tạo thực thể cho mô hình Griffin bằng recurrentgemma.jax.Griffin:

model = recurrentgemma.Griffin(config)

Tạo một samplerrecurrentgemma.jax.Sampler ở đầu điểm kiểm tra/trọng số mô hình RecurrentGemma và trình tạo mã thông báo để kiểm tra xem mô hình của bạn có thể thực hiện việc dịch mã hay không:

sampler = recurrentgemma.Sampler(model=model, vocab=vocab, params=params)

Tinh chỉnh mô hình

Trong phần này, bạn sẽ:

  • Dùng lớp gemma.transformer.Transformer để tạo hàm chuyển và hàm mất dữ liệu chuyển tiếp.
  • Tạo vectơ vị trí và mặt nạ chú ý cho mã thông báo
  • Tạo một hàm bước huấn luyện bằng Flax.
  • Tạo bước xác thực mà không bị truyền ngược.
  • Tạo vòng lặp huấn luyện.
  • Tinh chỉnh mô hình Gemma.

Xác định lượt chuyển tiếp và hàm mất quyền bằng cách sử dụng lớp recurrentgemma.jax.griffin.Griffin. RecurrentGemma Griffin kế thừa từ flax.linen.Module và cung cấp 2 phương thức thiết yếu:

  • init: Khởi chạy các tham số của mô hình.
  • apply: Thực thi hàm __call__ của mô hình bằng cách sử dụng một nhóm tham số nhất định.

Vì bạn đang làm việc với các trọng số Gemma đã huấn luyện trước, nên bạn không cần phải sử dụng hàm init.

def forward_and_loss_fn(
    params,
    *,
    model: recurrentgemma.Griffin,
    input_tokens: jax.Array,            # Shape [B, L]
    input_mask: jax.Array,              # Shape [B, L]
    positions: jax.Array,               # Shape [B, L]
) -> jax.Array:
  """Forward pass and loss function.

  Args:
    params: model's input parameters.
    model: Griffin model to call.
    input_tokens: input tokens sequence, shape [B, L].
    input_mask: tokens to ignore when computing the loss, shape [B, L].
    positions: relative position of each token, shape [B, L].

  Returns:
    Softmax cross-entropy loss for the next-token prediction task.
  """
  batch_size = input_tokens.shape[0]
  # Forward pass on the input data.
  # No attention cache is needed here.
  # Exclude the last step as it does not appear in the targets.
  logits, _ = model.apply(
        {"params": params},
        tokens=input_tokens[:, :-1],
        segment_pos=positions[:, :-1],
        cache=None,
    )

  # Similarly, the first token cannot be predicteds.
  target_tokens = input_tokens[:, 1:]
  target_mask = input_mask[:, 1:]

  # Convert the target labels into one-hot encoded vectors.
  one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])

  # Don't update on unwanted tokens.
  one_hot = one_hot * target_mask.astype(one_hot.dtype)[...,None]

  # Normalization factor.
  norm_factor = batch_size * (jnp.sum(target_mask) + 1e-8)

  # Return the negative log-likelihood loss (NLL) function.
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) / norm_factor

Tạo hàm train_step thực hiện truyền ngược và cập nhật các tham số của mô hình cho phù hợp, trong đó:

Params = Mapping[str, Any]

def get_positions(example: jax.Array, pad_id : int) -> jax.Array:
  """Builds the position vector from the given tokens."""
  pad_mask = example != pad_id
  positions = jnp.cumsum(pad_mask, axis=-1)
  # Subtract one for all positions from the first valid one as they are
  # 0-indexed
  positions = positions - (positions >= 1)
  return positions

@functools.partial(
    jax.jit,
    static_argnames=['model', 'optimizer'],
    donate_argnames=['params', 'opt_state'],
)
def train_step(
    model: recurrentgemma.Griffin,
    params: Params,
    optimizer: optax.GradientTransformation,
    opt_state: optax.OptState,
    pad_id: int,
    example: TrainingInput,
) -> tuple[jax.Array, Params, optax.OptState]:
  """The train step.

  Args:
    model: The RecurrentGemma (Griffin) model.
    params: The model's input parameters.
    optimizer: The Optax optimizer to use.
    opt_state: The input optimizer's state.
    pad_id: The ID of the pad token.
    example: The input batch.

  Returns:
    Training loss, updated parameters, updated optimizer state.
  """

  positions = get_positions(example.input_tokens, pad_id)

  # Forward and backward passes.
  train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(
      params,
      model=model,
      input_tokens=example.input_tokens,
      input_mask=example.target_mask,
      positions=positions,
  )
  # Update the parameters.
  updates, opt_state = optimizer.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)

  return train_loss, params, opt_state

Tạo hàm validation_step mà không có truyền ngược:

@functools.partial(jax.jit, static_argnames=['model'])
def validation_step(
    model: recurrentgemma.Griffin,
    params: Params,
    pad_id: int,
    example: TrainingInput,
) -> jax.Array:
  return forward_and_loss_fn(
      params,
      model=model,
      input_tokens=example.input_tokens,
      input_mask=example.target_mask,
      positions=get_positions(example.input_tokens, pad_id),
  )

Xác định vòng lặp huấn luyện:

def train_loop(
    model: recurrentgemma.Griffin,
    params: Params,
    optimizer: optax.GradientTransformation,
    train_ds: Iterator[TrainingInput],
    validation_ds: Iterator[TrainingInput],
    num_steps: int | None = None,
    eval_every_n: int = 20,
):
  opt_state = jax.jit(optimizer.init)(params)

  step_counter = 0
  avg_loss=0

  # The first round of the validation loss.
  n_steps_eval = 0
  eval_loss = 0
  for val_example in validation_ds.as_numpy_iterator():
    eval_loss += validation_step(
        model, params, dataset_builder._tokenizer.pad_id, val_example
    )
    n_steps_eval += 1
  print(f"Start, validation loss: {eval_loss/n_steps_eval}")

  for train_example in train_ds:
    train_loss, params, opt_state = train_step(
        model=model,
        params=params,
        optimizer=optimizer,
        opt_state=opt_state,
        pad_id=dataset_builder._tokenizer.pad_id,
        example=train_example,
    )

    step_counter += 1
    avg_loss += train_loss
    if step_counter % eval_every_n == 0:
      eval_loss = 0

      n_steps_eval = 0
      val_iterator = validation_ds.as_numpy_iterator()
      for val_example in val_iterator:
        eval_loss += validation_step(
            model,
            params,
            dataset_builder._tokenizer.pad_id,
            val_example,
        )
        n_steps_eval +=1
      avg_loss /= eval_every_n
      eval_loss /= n_steps_eval
      print(f"STEP {step_counter} training loss: {avg_loss} - eval loss: {eval_loss}")
      avg_loss=0
    if num_steps is not None and step_counter > num_steps:
      break
  return params

Tại đây, bạn phải chọn một trình tối ưu hoá (Optax). Đối với các thiết bị có bộ nhớ nhỏ hơn, bạn nên sử dụng SGD vì SGD có mức sử dụng bộ nhớ thấp hơn nhiều. Để đạt được hiệu suất tinh chỉnh tốt nhất, hãy thử Adam-W. Siêu tham số tối ưu cho mỗi trình tối ưu hoá đối với tác vụ cụ thể trong sổ tay này được cung cấp trong ví dụ này cho điểm kiểm tra 2b-it.

def griffin_weight_decay_mask(params_like: optax.Params) -> Any:
  # Don't put weight decay on the RGLRU, the embeddings and any biases
  def enable_weight_decay(path: list[Any], _: Any) -> bool:
    # Parameters in the LRU and embedder
    path = [dict_key.key for dict_key in path]
    if 'rg_lru' in path or 'embedder' in path:
      return False
    # All biases and scales
    if path[-1] in ('b', 'scale'):
      return False
    return True

  return jax.tree_util.tree_map_with_path(enable_weight_decay, params_like)

optimizer_choice = "sgd"

if optimizer_choice == "sgd":
  optimizer = optax.sgd(learning_rate=1e-3)
  num_steps = 300
elif optimizer_choice == "adamw":
  optimizer = optax.adamw(
        learning_rate=1e-4,
        b2=0.96,
        eps=1e-8,
        weight_decay=0.1,
        mask=griffin_weight_decay_mask,
    )
  num_steps = 100
else:
  raise ValueError(f"Unknown optimizer: {optimizer_choice}")

Chuẩn bị tập dữ liệu huấn luyện và xác thực:

# Choose a small sequence length size, so that everything fits in memory.
num_epochs = 1
batch_size = 1
sequence_length = 32

# Make the dataset builder.
tokenizer = GriffinTokenizer(vocab)
dataset_builder= MTNTDatasetBuilder(tokenizer, sequence_length + 1)

# Build the training dataset.
train_ds = dataset_builder.get_train_dataset(
    batch_size=batch_size,
    num_epochs=num_epochs,
).as_numpy_iterator()

# Build the validation dataset, with a limited number of samples for this demo.
validation_ds = dataset_builder.get_validation_dataset(
    batch_size=batch_size,
).take(50)

Bắt đầu tinh chỉnh mô hình RecurrentGemma (Griffin) với một số bước giới hạn (num_steps):

trained_params = train_loop(
    model=model,
    params=params,
    optimizer=optimizer,
    train_ds=train_ds,
    validation_ds=validation_ds,
    num_steps=num_steps,
)
Start, validation loss: 7.894117832183838
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,33]), ShapedArray(bool[1,33]), ShapedArray(int32[], weak_type=True).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
STEP 20 training loss: 4.592616081237793 - eval loss: 2.847407102584839
STEP 40 training loss: 2.7537424564361572 - eval loss: 2.9258534908294678
STEP 60 training loss: 2.835618257522583 - eval loss: 2.4382340908050537
STEP 80 training loss: 2.6322107315063477 - eval loss: 2.3696839809417725
STEP 100 training loss: 1.8703256845474243 - eval loss: 2.355681896209717
STEP 120 training loss: 2.7280433177948 - eval loss: 2.4059958457946777
STEP 140 training loss: 2.3047447204589844 - eval loss: 2.083082914352417
STEP 160 training loss: 2.3432137966156006 - eval loss: 2.095074415206909
STEP 180 training loss: 2.1081202030181885 - eval loss: 2.006460189819336
STEP 200 training loss: 2.5359647274017334 - eval loss: 1.9667452573776245
STEP 220 training loss: 2.202195644378662 - eval loss: 1.9440618753433228
STEP 240 training loss: 2.756615400314331 - eval loss: 2.1073737144470215
STEP 260 training loss: 2.5128934383392334 - eval loss: 2.117241859436035
STEP 280 training loss: 2.73045015335083 - eval loss: 1.9159646034240723
STEP 300 training loss: 2.0918595790863037 - eval loss: 1.9742532968521118

Cả mất dữ liệu huấn luyện và mất xác thực đều sẽ giảm theo số bước.

Để đảm bảo dữ liệu đầu vào của bạn phù hợp với định dạng huấn luyện, hãy nhớ sử dụng tiền tố Translate this into French:\n và ký tự dòng mới ở cuối. Việc này sẽ báo hiệu cho mô hình để bắt đầu dịch.

sampler.params = trained_params
output = sampler(
    ["Translate this into French:\nHello, my name is Morgane.\n"],
    total_generation_steps=100,
)
print(output.text[0])
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,16]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
Mais je m'appelle Morgane.

Tìm hiểu thêm