تنظیم دقیق RecurrentGemma با استفاده از JAX و Flax

مشاهده در ai.google.dev در Google Colab اجرا شود در Vertex AI باز کنید مشاهده منبع در GitHub

این آموزش نحوه تنظیم دقیق مدل RecurrentGemma 2B Instruct را برای کار ترجمه انگلیسی به فرانسوی با استفاده از کتابخانه recurrentgemma Google DeepMind ، JAX (یک کتابخانه محاسباتی عددی با کارایی بالا)، Flax (کتابخانه شبکه عصبی مبتنی بر JAX)، Chex ( کتابخانه ای از ابزارهای کاربردی برای نوشتن کد قابل اعتماد JAX)، Optax (کتابخانه پردازش گرادیان و بهینه سازی مبتنی بر JAX) و مجموعه داده MTNT (ترجمه ماشینی متن پر سر و صدا) . اگرچه از Flax مستقیماً در این نوت بوک استفاده نمی شود، از Flax برای ایجاد Gemma استفاده شده است.

کتابخانه recurrentgemma با JAX، Flax، Orbax (یک کتابخانه مبتنی بر JAX برای ابزارهای آموزشی مانند checkpointing) و SentencePiece (یک کتابخانه توکنایزر/دوکنیزر) نوشته شده است.

این نوت بوک می تواند در Google Colab با پردازنده گرافیکی T4 اجرا شود (به Edit > تنظیمات نوت بوک > زیر شتاب دهنده سخت افزار، T4 GPU را انتخاب کنید).

برپایی

بخش‌های زیر مراحل آماده‌سازی یک نوت‌بوک برای استفاده از مدل RecurrentGemma، از جمله دسترسی به مدل، دریافت کلید API و پیکربندی زمان اجرا نوت‌بوک را توضیح می‌دهند.

دسترسی Kaggle را برای Gemma تنظیم کنید

برای تکمیل این آموزش، ابتدا باید دستورالعمل های راه اندازی مشابه راه اندازی Gemma را با چند استثنا دنبال کنید:

  • در kaggle.com به RecurrentGemma (به جای Gemma) دسترسی پیدا کنید.
  • یک زمان اجرا Colab با منابع کافی برای اجرای مدل RecurrentGemma انتخاب کنید.
  • نام کاربری و کلید API Kaggle را ایجاد و پیکربندی کنید.

پس از تکمیل تنظیمات RecurrentGemma، به بخش بعدی بروید، جایی که متغیرهای محیطی را برای محیط Colab خود تنظیم خواهید کرد.

تنظیم متغیرهای محیطی

متغیرهای محیطی را برای KAGGLE_USERNAME و KAGGLE_KEY تنظیم کنید. وقتی از شما خواسته شد "دسترسی بدهید؟" پیام‌ها، با ارائه دسترسی مخفی موافقت کنید.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

کتابخانه recurrentgemma را نصب کنید

شتاب سخت‌افزار Free Colab در حال حاضر برای اجرای این نوت بوک کافی نیست . اگر از Colab Pay As You Go یا Colab Pro استفاده می‌کنید، روی Edit > تنظیمات نوت‌بوک > انتخاب A100 GPU > Save کلیک کنید تا شتاب سخت‌افزاری فعال شود.

در مرحله بعد، باید کتابخانه Google DeepMind recurrentgemma را از github.com/google-deepmind/recurrentgemma نصب کنید. اگر خطای «تحلیل کننده وابستگی پیپ» دریافت کردید، معمولاً می توانید آن را نادیده بگیرید.

pip install -q git+https://github.com/google-deepmind/recurrentgemma.git

واردات کتابخانه ها

این نوت‌بوک از Flax (برای شبکه‌های عصبی)، هسته JAX ، SentencePiece (برای توکن‌سازی)، Chex (کتابخانه‌ای از ابزارهای کاربردی برای نوشتن کد قابل اعتماد JAX)، Optax (کتابخانه پردازش گرادیان و بهینه‌سازی) و TensorFlow Datasets استفاده می‌کند.

import pathlib
from typing import Any, Mapping, Iterator
import enum
import functools

import chex
import jax
import jax.numpy as jnp
import optax

import tensorflow as tf
import tensorflow_datasets as tfds

import sentencepiece as spm

from recurrentgemma import jax as recurrentgemma

مدل RecurrentGemma را بارگیری کنید

  1. مدل RecurrentGemma را با kagglehub.model_download بارگیری کنید که سه آرگومان می گیرد:
  • handle : مدل دسته از Kaggle
  • path : (رشته اختیاری) مسیر محلی
  • force_download : (بولی اختیاری) بارگیری مجدد مدل را مجبور می کند
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download...
100%|██████████| 3.85G/3.85G [00:50<00:00, 81.5MB/s]
Extracting model files...
print('RECURRENTGEMMA_VARIANT:', RECURRENTGEMMA_VARIANT)
RECURRENTGEMMA_VARIANT: 2b-it
  1. محل وزن های مدل و توکنایزر را بررسی کنید، سپس متغیرهای مسیر را تنظیم کنید. دایرکتوری توکنایزر در دایرکتوری اصلی جایی که مدل را دانلود کرده‌اید قرار می‌گیرد، در حالی که وزن‌های مدل در یک دایرکتوری فرعی قرار دارند. مثلا:
  • فایل tokenizer.model در /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1 خواهد بود.
  • نقطه بازرسی مدل در /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it خواهد بود.
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model

مجموعه داده MTNT و توکنایزر Gemma را بارگیری و آماده کنید

شما از مجموعه داده MTNT (ترجمه ماشینی متن پر سر و صدا) استفاده خواهید کرد که از TensorFlow Datasets در دسترس است.

بخش داده انگلیسی به فرانسوی مجموعه داده MTNT را دانلود کنید و سپس دو نمونه را نمونه کنید. هر نمونه در مجموعه داده شامل دو ورودی است: src : جمله اصلی انگلیسی. و dst : ترجمه فرانسوی مربوطه.

ds = tfds.load("mtnt/en-fr", split="train")

ds = ds.take(2)
ds = ds.as_numpy_iterator()

for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()
Downloading and preparing dataset 35.08 MiB (download: 35.08 MiB, generated: 11.33 MiB, total: 46.41 MiB) to /root/tensorflow_datasets/mtnt/en-fr/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Extraction completed...: 0 file [00:00, ? file/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/35692 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incompleteJLH33K/mtnt-train.tfrecord*...:   0%|          …
Generating test examples...:   0%|          | 0/1020 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incompleteJLH33K/mtnt-test.tfrecord*...:   0%|          |…
Generating valid examples...:   0%|          | 0/811 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incompleteJLH33K/mtnt-valid.tfrecord*...:   0%|          …
Dataset mtnt downloaded and prepared to /root/tensorflow_datasets/mtnt/en-fr/1.0.0. Subsequent calls will reuse this data.
Example 0:
dst: b'Le groupe de " toutes les \xc3\xa9toiles potentielles de la conf\xc3\xa9rence de l\'Est mais qui ne s\'en sortent pas dans le groupe de l\'Ouest ".'
src: b'The group of \xe2\x80\x9ceastern conference potential all stars but not making it in the West\xe2\x80\x9d group.'

Example 1:
dst: b"Kameron est-elle un peu aigrie de son manque de temps \xc3\xa0 l'\xc3\xa9cran ?"
src: b'Is Kameron a Little Salty About Her Lack of Air Time?'

توکنایزر Gemma را بارگیری کنید که با استفاده از sentencepiece.SentencePieceProcessor ساخته شده است.SentencePieceProcessor:

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True

SentencePieceProcessor برای کار ترجمه انگلیسی به فرانسوی سفارشی کنید. از آنجایی که بخش انگلیسی مدل RecurrentGemma (Griffin) را به دقت تنظیم خواهید کرد، باید تنظیماتی را انجام دهید، مانند:

  • پیشوند ورودی : افزودن یک پیشوند مشترک به هر ورودی سیگنال وظیفه ترجمه را می دهد. برای مثال، می‌توانید از یک درخواست با پیشوندی مانند Translate this into French: [INPUT_SENTENCE] .

  • پسوند شروع ترجمه : اضافه کردن یک پسوند در انتهای هر فرمان به مدل Gemma دستور می دهد که دقیقا چه زمانی فرآیند ترجمه را آغاز کند. یک خط جدید باید این کار را انجام دهد.

  • نشانه‌های مدل زبان : مدل‌های RecurrentGemma (Griffin) در ابتدای هر دنباله، توکن «شروع دنباله‌ای» را انتظار دارند. به طور مشابه، باید در پایان هر مثال آموزشی یک نشانه "پایان دنباله" اضافه کنید.

یک بسته بندی سفارشی در اطراف SentencePieceProcessor به شرح زیر بسازید:

class GriffinTokenizer:
  """A custom wrapper around a SentencePieceProcessor."""

  def __init__(self, spm_processor: spm.SentencePieceProcessor):
    self._spm_processor = spm_processor

  @property
  def pad_id(self) -> int:
    """Fast access to the pad ID."""
    return self._spm_processor.pad_id()

  def tokenize(
      self,
      example: str | bytes,
      prefix: str = '',
      suffix: str = '',
      add_eos: bool = True,
  ) -> jax.Array:
    """
    A tokenization function.

    Args:
      example: Input string to tokenize.
      prefix:  Prefix to add to the input string.
      suffix:  Suffix to add to the input string.
      add_eos: If True, add an end of sentence token at the end of the output
               sequence.
    Returns:
      Tokens corresponding to the input string.
    """
    int_list = [self._spm_processor.bos_id()]
    int_list.extend(self._spm_processor.EncodeAsIds(prefix + example + suffix))
    if add_eos:
      int_list.append(self._spm_processor.eos_id())

    return jnp.array(int_list, dtype=jnp.int32)

  def tokenize_tf_op(
      self,
      str_tensor: tf.Tensor,
      prefix: str = '',
      suffix: str = '',
      add_eos: bool = True,
  ) -> tf.Tensor:
    """A TensforFlow operator for the `tokenize` function."""
    encoded = tf.numpy_function(
        self.tokenize,
        [str_tensor, prefix, suffix, add_eos],
        tf.int32)
    encoded.set_shape([None])
    return encoded

  def to_string(self, tokens: jax.Array) -> str:
    """Convert an array of tokens to a string."""
    return self._spm_processor.EncodeIds(tokens.tolist())

آن را با نمونه سازی GriffinTokenizer سفارشی جدید خود، و سپس اعمال آن بر روی نمونه کوچکی از مجموعه داده MTNT امتحان کنید:

def tokenize_source(tokenizer, example: tf.Tensor):
  return tokenizer.tokenize_tf_op(
      example,
      prefix='Translate this into French:\n',
      suffix='\n',
      add_eos=False
  )
def tokenize_destination(tokenizer, example: tf.Tensor):
  return tokenizer.tokenize_tf_op(example, add_eos=True)

tokenizer = GriffinTokenizer(vocab)

ds = tfds.load("mtnt/en-fr",split="train")
ds = ds.take(2)
ds = ds.map(lambda x: {
    'src': tokenize_source(tokenizer, x['src']),
    'dst': tokenize_destination(tokenizer, x['dst'])
  })
ds = ds.as_numpy_iterator()

for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()
Example 0:
src: [     2  49688    736   1280   6987 235292    108    651   2778    576
   1080 104745  11982   5736    832   8995    901    780   3547    665
    575    573   4589 235369   2778 235265    108]
dst: [     2   2025  29653    581    664  16298   1437  55563  41435   7840
    581    683 111452    581    533 235303   9776   4108   2459    679
    485 235303    479   6728    579   1806   2499    709  29653    581
    533 235303 101323  16054      1]

Example 1:
src: [     2  49688    736   1280   6987 235292    108   2437  87150    477
    476  11709 230461   8045   3636  40268    576   4252   4897 235336
    108]
dst: [     2 213606    477   1455 235290   3510    748   8268 191017   2809
    581   2032  69972    581  11495   1305    533 235303  65978   1654
      1]

یک بارگذار داده برای کل مجموعه داده MTNT بسازید:

@chex.dataclass(frozen=True)
class TrainingInput:
  # Input tokens provided to the model.
  input_tokens: jax.Array

  # A mask that determines which tokens contribute to the target loss
  # calculation.
  target_mask: jax.Array

class DatasetSplit(enum.Enum):
  TRAIN = 'train'
  VALIDATION = 'valid'


class MTNTDatasetBuilder:
  """A data loader for the MTNT dataset."""

  N_ITEMS = {DatasetSplit.TRAIN: 35_692, DatasetSplit.VALIDATION: 811}

  BUFFER_SIZE_SHUFFLE = 10_000
  TRANSLATION_PREFIX = 'Translate this into French:\n'
  TRANSLATION_SUFFIX = '\n'

  def __init__(self,
               tokenizer : GriffinTokenizer,
               max_seq_len: int):
    """A constructor.

    Args:
      tokenizer: The tokenizer to use.
      max_seq_len: The size of each sequence in a given batch.
    """
    self._tokenizer = tokenizer
    self._base_data = {
        DatasetSplit.TRAIN: tfds.load("mtnt/en-fr",split="train"),
        DatasetSplit.VALIDATION: tfds.load("mtnt/en-fr",split="valid"),
    }
    self._max_seq_len = max_seq_len

  def _tokenize_source(self, example: tf.Tensor):
    """A tokenization function for the source."""
    return self._tokenizer.tokenize_tf_op(
        example, prefix=self.TRANSLATION_PREFIX, suffix=self.TRANSLATION_SUFFIX,
        add_eos=False
    )

  def _tokenize_destination(self, example: tf.Tensor):
    """A tokenization function for the French translation."""
    return self._tokenizer.tokenize_tf_op(example, add_eos=True)

  def _pad_up_to_max_len(self,
                         input_tensor: tf.Tensor,
                         pad_value: int | bool,
                         ) -> tf.Tensor:
    """Pad the given tensor up to sequence length of a batch."""
    seq_len = tf.shape(input_tensor)[0]
    to_pad = tf.maximum(self._max_seq_len - seq_len, 0)
    return tf.pad(
        input_tensor, [[0, to_pad]], mode='CONSTANT', constant_values=pad_value,
    )

  def _to_training_input(
      self,
      src_tokens: jax.Array,
      dst_tokens: jax.Array,
  ) -> TrainingInput:
    """Build a training input from a tuple of source and destination tokens."""

    # The input sequence fed to the model is simply the concatenation of the
    # source and the destination.
    tokens = tf.concat([src_tokens, dst_tokens], axis=0)

    # You want to prevent the model from updating based on the source (input)
    # tokens. To achieve this, add a target mask to each input.
    q_mask = tf.zeros_like(src_tokens, dtype=tf.bool)
    a_mask = tf.ones_like(dst_tokens, dtype=tf.bool)
    mask = tf.concat([q_mask, a_mask], axis=0)

    # If the output tokens sequence is smaller than the target sequence size,
    # then pad it with pad tokens.
    tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)

    # You don't want to perform the backward on the pad tokens.
    mask = self._pad_up_to_max_len(mask, False)

    return TrainingInput(input_tokens=tokens, target_mask=mask)


  def get_train_dataset(self, batch_size: int, num_epochs: int):
    """Build the training dataset."""

    # Tokenize each sample.
    ds = self._base_data[DatasetSplit.TRAIN].map(
        lambda x : (self._tokenize_source(x['src']),
                    self._tokenize_destination(x['dst']))
    )

    # Convert them to training inputs.
    ds = ds.map(lambda x, y: self._to_training_input(x, y))

    # Remove the samples which are too long.
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)

    # Shuffle the dataset.
    ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)

    # Repeat if necessary.
    ds = ds.repeat(num_epochs)

    # Build batches.
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

  def get_validation_dataset(self, batch_size: int):
    """Build the validation dataset."""

    # Same as the training dataset, but no shuffling and no repetition
    ds = self._base_data[DatasetSplit.VALIDATION].map(
        lambda x : (self._tokenize_source(x['src']),
                    self._tokenize_destination(x['dst']))
    )
    ds = ds.map(lambda x, y: self._to_training_input(x, y))
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

MTNTDatasetBuilder را با نمونه سازی مجدد GriffinTokenizer سفارشی، سپس اعمال آن بر روی مجموعه داده MTNT و نمونه برداری از دو نمونه امتحان کنید:

dataset_builder = MTNTDatasetBuilder(tokenizer, max_seq_len=20)
ds = dataset_builder.get_train_dataset(3, 1)
ds = ds.take(2)
ds = ds.as_numpy_iterator()

for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'>
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'>
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'>
Example 0:
input_tokens: [[     2  49688    736   1280   6987 235292    108  12583    665 235265
     108      2   6151  94975   1320   6238 235265      1      0      0]
 [     2  49688    736   1280   6987 235292    108   4899  29960  11270
  108282 235265    108      2   4899  79025  11270 108282      1      0]
 [     2  49688    736   1280   6987 235292    108  26620 235265    108
       2  26620 235265      1      0      0      0      0      0      0]]
target_mask: [[False False False False False False False False False False False  True
   True  True  True  True  True  True False False]
 [False False False False False False False False False False False False
  False  True  True  True  True  True  True False]
 [False False False False False False False False False False  True  True
   True  True False False False False False False]]

Example 1:
input_tokens: [[     2  49688    736   1280   6987 235292    108    527   5174   1683
  235336    108      2 206790    581  20726    482   2208   1654      1]
 [     2  49688    736   1280   6987 235292    108  28484 235256 235336
     108      2 120500  13832   1654      1      0      0      0      0]
 [     2  49688    736   1280   6987 235292    108 235324 235304   2705
  235265    108      2 235324 235304  19963 235265      1      0      0]]
target_mask: [[False False False False False False False False False False False False
   True  True  True  True  True  True  True  True]
 [False False False False False False False False False False False  True
   True  True  True  True False False False False]
 [False False False False False False False False False False False False
   True  True  True  True  True  True False False]]

مدل را پیکربندی کنید

قبل از شروع به تنظیم دقیق مدل Gemma، باید آن را پیکربندی کنید.

بازرسی مدل RecurrentGemma (Griffin) را با روش recurrentgemma.jax.utils.load_parameters بارگیری کنید:

params =  recurrentgemma.load_parameters(CKPT_PATH, "single_device")

برای بارگیری خودکار پیکربندی صحیح از نقطه بازرسی مدل RecurrentGemma، از recurrentgemma.GriffinConfig.from_flax_params_or_variables استفاده کنید:

config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(params)

مدل گریفین را با recurrentgemma.jax.Griffin نمونه سازی کنید:

model = recurrentgemma.Griffin(config)

یک sampler با recurrentgemma.jax.Sampler در بالای نقطه/وزن‌های مدل RecurrentGemma و توکنایزر ایجاد کنید تا بررسی کنید آیا مدل شما می‌تواند ترجمه را انجام دهد:

sampler = recurrentgemma.Sampler(model=model, vocab=vocab, params=params)

مدل را دقیق تنظیم کنید

در این بخش، شما:

  • از کلاس gemma.transformer.Transformer برای ایجاد تابع گذر و ضرر به جلو استفاده کنید.
  • بردارهای ماسک موقعیت و توجه را برای توکن ها بسازید
  • یک تابع مرحله آموزشی با Flax بسازید.
  • مرحله اعتبار سنجی را بدون پاس رو به عقب بسازید.
  • حلقه آموزش را ایجاد کنید.
  • مدل جما را دقیق تنظیم کنید.

با استفاده از کلاس recurrentgemma.jax.griffin.Griffin ، پاس رو به جلو و تابع ضرر را تعریف کنید. RecurrentGemma Griffin از flax.linen.Module به ارث می رسد و دو روش ضروری را ارائه می دهد:

  • init : پارامترهای مدل را مقدار دهی اولیه می کند.
  • apply : تابع __call__ مدل را با استفاده از مجموعه ای از پارامترها اجرا می کند.

از آنجایی که شما با وزنه های Gemma از قبل آموزش داده شده کار می کنید، نیازی به استفاده از تابع init ندارید.

def forward_and_loss_fn(
    params,
    *,
    model: recurrentgemma.Griffin,
    input_tokens: jax.Array,            # Shape [B, L]
    input_mask: jax.Array,              # Shape [B, L]
    positions: jax.Array,               # Shape [B, L]
) -> jax.Array:
  """Forward pass and loss function.

  Args:
    params: model's input parameters.
    model: Griffin model to call.
    input_tokens: input tokens sequence, shape [B, L].
    input_mask: tokens to ignore when computing the loss, shape [B, L].
    positions: relative position of each token, shape [B, L].

  Returns:
    Softmax cross-entropy loss for the next-token prediction task.
  """
  batch_size = input_tokens.shape[0]
  # Forward pass on the input data.
  # No attention cache is needed here.
  # Exclude the last step as it does not appear in the targets.
  logits, _ = model.apply(
        {"params": params},
        tokens=input_tokens[:, :-1],
        segment_pos=positions[:, :-1],
        cache=None,
    )

  # Similarly, the first token cannot be predicteds.
  target_tokens = input_tokens[:, 1:]
  target_mask = input_mask[:, 1:]

  # Convert the target labels into one-hot encoded vectors.
  one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])

  # Don't update on unwanted tokens.
  one_hot = one_hot * target_mask.astype(one_hot.dtype)[...,None]

  # Normalization factor.
  norm_factor = batch_size * (jnp.sum(target_mask) + 1e-8)

  # Return the negative log-likelihood loss (NLL) function.
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) / norm_factor

تابع train_step را بسازید که گذر به عقب را انجام می دهد و پارامترهای مدل را بر اساس آن به روز می کند، در جایی که:

  • jax.value_and_grad برای ارزیابی تابع از دست دادن و گرادیان در طول پاس های جلو و عقب است.
  • optax.apply_updates برای به روز رسانی پارامترها است.
Params = Mapping[str, Any]

def get_positions(example: jax.Array, pad_id : int) -> jax.Array:
  """Builds the position vector from the given tokens."""
  pad_mask = example != pad_id
  positions = jnp.cumsum(pad_mask, axis=-1)
  # Subtract one for all positions from the first valid one as they are
  # 0-indexed
  positions = positions - (positions >= 1)
  return positions

@functools.partial(
    jax.jit,
    static_argnames=['model', 'optimizer'],
    donate_argnames=['params', 'opt_state'],
)
def train_step(
    model: recurrentgemma.Griffin,
    params: Params,
    optimizer: optax.GradientTransformation,
    opt_state: optax.OptState,
    pad_id: int,
    example: TrainingInput,
) -> tuple[jax.Array, Params, optax.OptState]:
  """The train step.

  Args:
    model: The RecurrentGemma (Griffin) model.
    params: The model's input parameters.
    optimizer: The Optax optimizer to use.
    opt_state: The input optimizer's state.
    pad_id: The ID of the pad token.
    example: The input batch.

  Returns:
    Training loss, updated parameters, updated optimizer state.
  """

  positions = get_positions(example.input_tokens, pad_id)

  # Forward and backward passes.
  train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(
      params,
      model=model,
      input_tokens=example.input_tokens,
      input_mask=example.target_mask,
      positions=positions,
  )
  # Update the parameters.
  updates, opt_state = optimizer.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)

  return train_loss, params, opt_state

تابع validation_step را بدون عبور برگشت بسازید:

@functools.partial(jax.jit, static_argnames=['model'])
def validation_step(
    model: recurrentgemma.Griffin,
    params: Params,
    pad_id: int,
    example: TrainingInput,
) -> jax.Array:
  return forward_and_loss_fn(
      params,
      model=model,
      input_tokens=example.input_tokens,
      input_mask=example.target_mask,
      positions=get_positions(example.input_tokens, pad_id),
  )

تعریف حلقه آموزشی:

def train_loop(
    model: recurrentgemma.Griffin,
    params: Params,
    optimizer: optax.GradientTransformation,
    train_ds: Iterator[TrainingInput],
    validation_ds: Iterator[TrainingInput],
    num_steps: int | None = None,
    eval_every_n: int = 20,
):
  opt_state = jax.jit(optimizer.init)(params)

  step_counter = 0
  avg_loss=0

  # The first round of the validation loss.
  n_steps_eval = 0
  eval_loss = 0
  for val_example in validation_ds.as_numpy_iterator():
    eval_loss += validation_step(
        model, params, dataset_builder._tokenizer.pad_id, val_example
    )
    n_steps_eval += 1
  print(f"Start, validation loss: {eval_loss/n_steps_eval}")

  for train_example in train_ds:
    train_loss, params, opt_state = train_step(
        model=model,
        params=params,
        optimizer=optimizer,
        opt_state=opt_state,
        pad_id=dataset_builder._tokenizer.pad_id,
        example=train_example,
    )

    step_counter += 1
    avg_loss += train_loss
    if step_counter % eval_every_n == 0:
      eval_loss = 0

      n_steps_eval = 0
      val_iterator = validation_ds.as_numpy_iterator()
      for val_example in val_iterator:
        eval_loss += validation_step(
            model,
            params,
            dataset_builder._tokenizer.pad_id,
            val_example,
        )
        n_steps_eval +=1
      avg_loss /= eval_every_n
      eval_loss /= n_steps_eval
      print(f"STEP {step_counter} training loss: {avg_loss} - eval loss: {eval_loss}")
      avg_loss=0
    if num_steps is not None and step_counter > num_steps:
      break
  return params

در اینجا شما باید یک بهینه ساز (Optax) را انتخاب کنید. برای دستگاه‌هایی که حافظه کمتری دارند، باید از SGD استفاده کنید، زیرا حافظه بسیار کمتری دارد. برای دستیابی به بهترین عملکرد تنظیم دقیق، Adam-W را امتحان کنید. فراپارامترهای بهینه برای هر بهینه ساز برای کار خاص در این نوت بوک در این مثال برای نقطه بازرسی 2b-it ارائه شده است.

def griffin_weight_decay_mask(params_like: optax.Params) -> Any:
  # Don't put weight decay on the RGLRU, the embeddings and any biases
  def enable_weight_decay(path: list[Any], _: Any) -> bool:
    # Parameters in the LRU and embedder
    path = [dict_key.key for dict_key in path]
    if 'rg_lru' in path or 'embedder' in path:
      return False
    # All biases and scales
    if path[-1] in ('b', 'scale'):
      return False
    return True

  return jax.tree_util.tree_map_with_path(enable_weight_decay, params_like)

optimizer_choice = "sgd"

if optimizer_choice == "sgd":
  optimizer = optax.sgd(learning_rate=1e-3)
  num_steps = 300
elif optimizer_choice == "adamw":
  optimizer = optax.adamw(
        learning_rate=1e-4,
        b2=0.96,
        eps=1e-8,
        weight_decay=0.1,
        mask=griffin_weight_decay_mask,
    )
  num_steps = 100
else:
  raise ValueError(f"Unknown optimizer: {optimizer_choice}")

مجموعه داده های آموزشی و اعتبار سنجی را آماده کنید:

# Choose a small sequence length size, so that everything fits in memory.
num_epochs = 1
batch_size = 1
sequence_length = 32

# Make the dataset builder.
tokenizer = GriffinTokenizer(vocab)
dataset_builder= MTNTDatasetBuilder(tokenizer, sequence_length + 1)

# Build the training dataset.
train_ds = dataset_builder.get_train_dataset(
    batch_size=batch_size,
    num_epochs=num_epochs,
).as_numpy_iterator()

# Build the validation dataset, with a limited number of samples for this demo.
validation_ds = dataset_builder.get_validation_dataset(
    batch_size=batch_size,
).take(50)

شروع به تنظیم دقیق مدل RecurrentGemma (Griffin) در تعداد محدودی از مراحل ( num_steps ):

trained_params = train_loop(
    model=model,
    params=params,
    optimizer=optimizer,
    train_ds=train_ds,
    validation_ds=validation_ds,
    num_steps=num_steps,
)
Start, validation loss: 7.894117832183838
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,33]), ShapedArray(bool[1,33]), ShapedArray(int32[], weak_type=True).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
STEP 20 training loss: 4.592616081237793 - eval loss: 2.847407102584839
STEP 40 training loss: 2.7537424564361572 - eval loss: 2.9258534908294678
STEP 60 training loss: 2.835618257522583 - eval loss: 2.4382340908050537
STEP 80 training loss: 2.6322107315063477 - eval loss: 2.3696839809417725
STEP 100 training loss: 1.8703256845474243 - eval loss: 2.355681896209717
STEP 120 training loss: 2.7280433177948 - eval loss: 2.4059958457946777
STEP 140 training loss: 2.3047447204589844 - eval loss: 2.083082914352417
STEP 160 training loss: 2.3432137966156006 - eval loss: 2.095074415206909
STEP 180 training loss: 2.1081202030181885 - eval loss: 2.006460189819336
STEP 200 training loss: 2.5359647274017334 - eval loss: 1.9667452573776245
STEP 220 training loss: 2.202195644378662 - eval loss: 1.9440618753433228
STEP 240 training loss: 2.756615400314331 - eval loss: 2.1073737144470215
STEP 260 training loss: 2.5128934383392334 - eval loss: 2.117241859436035
STEP 280 training loss: 2.73045015335083 - eval loss: 1.9159646034240723
STEP 300 training loss: 2.0918595790863037 - eval loss: 1.9742532968521118

هم ضرر آموزش و هم ضرر اعتبار باید با شمارش هر مرحله کاهش می یافت.

برای اطمینان از مطابقت ورودی شما با قالب آموزشی، به یاد داشته باشید که از پیشوند Translate this into French:\n و یک کاراکتر خط جدید در پایان استفاده کنید. این به مدل سیگنال می دهد که ترجمه را آغاز کند.

sampler.params = trained_params
output = sampler(
    ["Translate this into French:\nHello, my name is Morgane.\n"],
    total_generation_steps=100,
)
print(output.text[0])
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,16]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
Mais je m'appelle Morgane.

بیشتر بدانید