ضبط Gemma باستخدام JAX وFlax

العرض على ai.google.dev التنفيذ في Google Colab الفتح في Vertex AI عرض المصدر على GitHub

نظرة عامة

"جيما" هي عائلة تتألّف من نماذج لغوية كبيرة وحديثة وبسيطة ومفتوحة استنادًا إلى أبحاث وتقنيات Google DeepMind Gemini. يوضح هذا البرنامج التعليمي كيفية تحسين نموذج Gemma 2B Instruct لمهمة الترجمة من الإنجليزية والفرنسية باستخدام مكتبة gemma من Google DeepMind وJAX (مكتبة حوسبة رقمية عالية الأداء) وFlax (مكتبة الشبكة العصبية المستندة إلى JAX) وChex (مكتبة من الأدوات المساعدة لكتابة رمز JAX المعروف بترميز JAX) وOptax على الرغم من عدم استخدام Flax مباشرةً في هذا الدفتر، فقد تم استخدامه لإنشاء Gemma.

تمت كتابة مكتبة gemma باستخدام JAX وFlax وOrbax (مكتبة مستندة إلى JAX لأدوات التدريب مثل ميزة "فحص نقطة") وSentencePiece (مكتبة أداة إنشاء الرموز المميّزة أو أداة إزالة الرموز المميّزة).

ضبط إعدادات الجهاز

1. إعداد وصول Kaggle لـ Gemma

لإكمال هذا البرنامج التعليمي، عليك أولاً اتّباع تعليمات الإعداد في إعداد Gemma، والتي توضِّح لك كيفية إجراء ما يلي:

  • يمكنك الوصول إلى Gemma على kaggle.com.
  • اختَر بيئة تشغيل Colab بها موارد كافية لتشغيل نموذج Gemma.
  • إنشاء وتكوين اسم مستخدم ومفتاح واجهة برمجة تطبيقات Kaggle.

بعد الانتهاء من إعداد Gemma، انتقِل إلى القسم التالي، حيث يمكنك ضبط متغيّرات البيئة لبيئة Colab.

2. ضبط متغيرات البيئة

ضبط متغيّرات البيئة لكل من KAGGLE_USERNAME وKAGGLE_KEY عندما يُطلب منك عرض رسالة "هل تريد منح إمكانية الوصول؟" الموافقة على توفير الوصول السري.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3- تثبيت مكتبة gemma

إنّ ميزة تسريع أجهزة Colab المجانية غير كافية حاليًا لتشغيل ورقة الملاحظات هذه. في حال استخدام Colab Pay As You Go أو Colab Pro، انقر على تعديل > إعدادات ورقة الملاحظات > اختَر وحدة معالجة رسومات A100 > انقر على حفظ لتفعيل ميزة "تسريع الأجهزة".

بعد ذلك، عليك تثبيت مكتبة Google DeepMind gemma من github.com/google-deepmind/gemma. إذا ظهرت لك رسالة خطأ بشأن "أداة حل التبعية في pip"، يمكنك عادةً تجاهلها.

pip install -q git+https://github.com/google-deepmind/gemma.git

4. استيراد المكتبات

ويستخدم ورقة الملاحظات هذه برامج Flax (للشبكات العصبية) وJAX وSentencePiece الأساسية (لترميز الرموز المميزة) وChex (مكتبة من الأدوات المساعدة لكتابة رمز JAX موثوق به) ومجموعات بيانات TensorFlow.

import os
import enum
import re
import string

import chex
import jax
import jax.numpy as jnp
import optax

import tensorflow as tf
import tensorflow_datasets as tfds

from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm

تحميل نموذج Gemma

حمِّل نموذج Gemma باستخدام kagglehub.model_download الذي يأخذ ثلاث وسيطات:

  • handle: مؤشر النموذج من Kaggle
  • path: (سلسلة اختيارية) المسار المحلي
  • force_download: (قيمة منطقية اختيارية) تفرض إعادة تنزيل النموذج
GEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/gemma/flax/2b-it/2/download...
100%|██████████| 3.67G/3.67G [00:26<00:00, 147MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2

تحقَّق من الموقع الجغرافي لقيم ترجيح النموذج وأداة إنشاء الرموز المميّزة، ثم اضبط متغيّرات المسار. سيكون دليل أداة إنشاء الرموز المميّزة في الدليل الرئيسي الذي نزّلت فيه النموذج، بينما ستكون القيم التقديرية للنموذج في دليل فرعي. على سبيل المثال:

  • سيكون الملف tokenizer.model باللغة /LOCAL/PATH/TO/gemma/flax/2b-it/2).
  • ستكون النقطة المرجعية للنموذج في /LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it).
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/tokenizer.model

حمِّل إعداد مجموعة بيانات MTNT ومزيل الرمز المميّز لـ Gemma

ستستخدم مجموعة بيانات MTNT (الترجمة الآلية للنص المزعج)، المتاحة من مجموعات بيانات TensorFlow.

قم بتنزيل جزء مجموعة البيانات من الإنجليزية إلى الفرنسية من مجموعة بيانات MTNT، ثم عينة من مثالين. تحتوي كل عيّنة في مجموعة البيانات على إدخالَين: src: الجملة الإنجليزية الأصلية. وdst: الترجمة الفرنسية المقابلة.

ds = tfds.load("mtnt/en-fr", split="train")

ds = ds.take(2)
ds = ds.as_numpy_iterator()

for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()
Downloading and preparing dataset 35.08 MiB (download: 35.08 MiB, generated: 11.33 MiB, total: 46.41 MiB) to /root/tensorflow_datasets/mtnt/en-fr/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Extraction completed...: 0 file [00:00, ? file/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/35692 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incomplete6YJMND/mtnt-train.tfrecord*...:   0%|          …
Generating test examples...:   0%|          | 0/1020 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incomplete6YJMND/mtnt-test.tfrecord*...:   0%|          |…
Generating valid examples...:   0%|          | 0/811 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incomplete6YJMND/mtnt-valid.tfrecord*...:   0%|          …
Dataset mtnt downloaded and prepared to /root/tensorflow_datasets/mtnt/en-fr/1.0.0. Subsequent calls will reuse this data.
Example 0:
dst: b'Le groupe de " toutes les \xc3\xa9toiles potentielles de la conf\xc3\xa9rence de l\'Est mais qui ne s\'en sortent pas dans le groupe de l\'Ouest ".'
src: b'The group of \xe2\x80\x9ceastern conference potential all stars but not making it in the West\xe2\x80\x9d group.'

Example 1:
dst: b"Kameron est-elle un peu aigrie de son manque de temps \xc3\xa0 l'\xc3\xa9cran ?"
src: b'Is Kameron a Little Salty About Her Lack of Air Time?'

تحميل برنامج ترميز Gemma المميز الذي تم إنشاؤه باستخدام sentencepiece.SentencePieceProcessor:

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True

خصِّص SentencePieceProcessor لمهمة الترجمة من الإنجليزية إلى الفرنسية. نظرًا لأنك ستضبط الجزء الإنجليزي من نموذج Gemma، فستحتاج إلى إجراء بعض التعديلات، مثل:

  • بادئة الإدخال: تؤدي إضافة بادئة مشتركة إلى كل إدخال إلى إرسال مهمة الترجمة. على سبيل المثال، يمكنك استخدام طلب يتضمّن بادئة مثل Translate this into French: [INPUT_SENTENCE].

  • لاحقة بدء الترجمة: من خلال إضافة لاحقة في نهاية كل طلب، يتم توجيه نموذج Gemma إلى الوقت المحدّد لبدء عملية الترجمة. يجب أن يؤدي سطر جديد هذه المهمة.

  • الرموز المميزة لنموذج اللغة: تتوقّع نماذج Gemma "بداية التسلسل" في بداية كل تسلسل، لذا فإن إضافة "نهاية التسلسل" في نهاية كل مثال تدريبي كافيًا.

    يمكنك إنشاء برنامج تضمين مخصّص حول SentencePieceProcessor على النحو التالي:

class GemmaTokenizer:

  def __init__(self,
               spm_processor: spm.SentencePieceProcessor):
    self._spm_processor = spm_processor

  @property
  def pad_id(self) -> int:
    """Fast access to the pad ID."""
    return self._spm_processor.pad_id()

  def tokenize(self,
               example: str | bytes,
               prefix: str = '',
               suffix: str = '',
               add_eos: bool = True) -> jax.Array:
    """
    The tokenization function.

    Args:
      example: Input string to tokenize.
      prefix:  Prefix to add to the input string.
      suffix:  Suffix to add to the input string.
      add_eos: If True, add an "end of sentence" token at the end of the output
               sequence.
    Returns:
      Tokens corresponding to the input string.
    """
    int_list = [self._spm_processor.bos_id()]
    int_list.extend(self._spm_processor.EncodeAsIds(prefix + example + suffix))
    if add_eos:
      int_list.append(self._spm_processor.eos_id())

    return jnp.array(int_list, dtype=jnp.int32)

  def tokenize_tf_op(self,
                     str_tensor: tf.Tensor,
                     prefix: str = '',
                     suffix: str = '',
                     add_eos: bool = True) -> tf.Tensor:
    """A TensorFlow operator for the tokenize function."""
    encoded = tf.numpy_function(
        self.tokenize,
        [str_tensor, prefix, suffix, add_eos],
        tf.int32)
    encoded.set_shape([None])
    return encoded

  def to_string(self, tokens: jax.Array) -> str:
    """Convert an array of tokens to a string."""
    return self._spm_processor.EncodeIds(tokens.tolist())

جرِّب هذه الميزة من خلال إنشاء مثيل لـ GemmaTokenizer المخصّص الجديد، ثم تطبيقه على عيّنة صغيرة من مجموعة بيانات MTNT:

tokenizer = GemmaTokenizer(vocab)

def tokenize_source(tokenizer, example: tf.Tensor):
  return tokenizer.tokenize_tf_op(example,
                                  prefix='Translate this into French:\n',
                                  suffix='\n',
                                  add_eos=False)
def tokenize_destination(tokenizer, example: tf.Tensor):
  return tokenizer.tokenize_tf_op(example,
                                  add_eos=True)

ds = tfds.load("mtnt/en-fr",split="train")
ds = ds.take(2)
ds = ds.map(lambda x: {'src': tokenize_source(tokenizer, x['src']),
                       'dst': tokenize_destination(tokenizer, x['dst'])})
ds = ds.as_numpy_iterator()

for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()
Example 0:
src: [     2  49688    736   1280   6987 235292    108    651   2778    576
   1080 104745  11982   5736    832   8995    901    780   3547    665
    575    573   4589 235369   2778 235265    108]
dst: [     2   2025  29653    581    664  16298   1437  55563  41435   7840
    581    683 111452    581    533 235303   9776   4108   2459    679
    485 235303    479   6728    579   1806   2499    709  29653    581
    533 235303 101323  16054      1]

Example 1:
src: [     2  49688    736   1280   6987 235292    108   2437  87150    477
    476  11709 230461   8045   3636  40268    576   4252   4897 235336
    108]
dst: [     2 213606    477   1455 235290   3510    748   8268 191017   2809
    581   2032  69972    581  11495   1305    533 235303  65978   1654
      1]

إنشاء أداة تحميل بيانات لمجموعة بيانات MTNT بأكملها:

@chex.dataclass(frozen=True)
class TrainingInput:
  # Input tokens provided to the model.
  input_tokens: jax.Array

  # A mask that determines which tokens contribute to the target loss
  # calculation.
  target_mask: jax.Array

class DatasetSplit(enum.Enum):
  TRAIN = 'train'
  VALIDATION = 'valid'

class MTNTDatasetBuilder:
  """The dataset builder for the MTNT dataset."""

  N_ITEMS = {DatasetSplit.TRAIN: 35_692,
             DatasetSplit.VALIDATION: 811}

  BUFFER_SIZE_SHUFFLE = 10_000
  TRANSLATION_PREFIX = 'Translate this into French:\n'
  TRANSLATION_SUFFIX = '\n'

  def __init__(self,
               tokenizer : GemmaTokenizer,
               max_seq_len: int):
    """Constructor.

    Args:
      tokenizer: Gemma tokenizer to use.
      max_seq_len: size of each sequence in a given batch.
    """
    self._tokenizer = tokenizer
    self._base_data = {
        DatasetSplit.TRAIN: tfds.load("mtnt/en-fr",split="train"),
        DatasetSplit.VALIDATION: tfds.load("mtnt/en-fr",split="valid"),
    }
    self._max_seq_len = max_seq_len

  def _tokenize_source(self, example: tf.Tensor):
    """Tokenization function for the source."""
    return self._tokenizer.tokenize_tf_op(example,
                                          prefix=self.TRANSLATION_PREFIX,
                                          suffix=self.TRANSLATION_SUFFIX,
                                          add_eos=False)

  def _tokenize_destination(self, example: tf.Tensor):
    """Tokenization function for the French translation."""
    return self._tokenizer.tokenize_tf_op(example,
                                          add_eos=True)

  def _pad_up_to_max_len(self,
                         input_tensor: tf.Tensor,
                         pad_value: int | bool,
                         ) -> tf.Tensor:
    """Pad the given tensor up to sequence length of a batch."""
    seq_len = tf.shape(input_tensor)[0]
    to_pad = tf.maximum(self._max_seq_len - seq_len, 0)
    return tf.pad(input_tensor,
                  [[0, to_pad]],
                  mode='CONSTANT',
                  constant_values=pad_value,
                  )

  def _to_training_input(self,
                         src_tokens: jax.Array,
                         dst_tokens: jax.Array,
                         ) -> TrainingInput:
    """Build a training input from a tuple of source and destination tokens."""

    # The input sequence fed to the model is simply the concatenation of the
    # source and the destination.
    tokens = tf.concat([src_tokens, dst_tokens], axis=0)

    # To prevent the model from updating based on the source (input)
    # tokens, add a target mask to each input.
    q_mask = tf.zeros_like(src_tokens, dtype=tf.bool)
    a_mask = tf.ones_like(dst_tokens, dtype=tf.bool)
    mask = tf.concat([q_mask, a_mask], axis=0)

    # If the output tokens sequence is smaller than the target sequence size,
    # then pad it with pad tokens.
    tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)

    # Don't want to perform the backward pass on the pad tokens.
    mask = self._pad_up_to_max_len(mask, False)

    return TrainingInput(input_tokens=tokens, target_mask=mask)


  def get_train_dataset(self, batch_size: int, num_epochs: int):
    """Build the training dataset."""

    # Tokenize each sample.
    ds = self._base_data[DatasetSplit.TRAIN].map(lambda x : (self._tokenize_source(x['src']),
                                                             self._tokenize_destination(x['dst'])))

    # Convert the samples to training inputs.
    ds = ds.map(lambda x, y: self._to_training_input(x, y))

    # Remove the samples that are too long.
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)

    # Shuffle the dataset.
    ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)

    # Repeat if necessary.
    ds = ds.repeat(num_epochs)

    # Build batches.
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

  def get_validation_dataset(self, batch_size: int):
    """Build the validation dataset."""

    # Same steps as in `get_train_dataset`, but without shuffling and no repetition.
    ds = self._base_data[DatasetSplit.VALIDATION].map(lambda x : (self._tokenize_source(x['src']),
                                                                  self._tokenize_destination(x['dst'])))
    ds = ds.map(lambda x, y: self._to_training_input(x, y))
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

جرِّب MTNTDatasetBuilder من خلال إنشاء مثيل للعنصر GemmaTokenizer المخصّص مرة أخرى، ثم تطبيقه على مجموعة بيانات MTNT، وأخذ عيّنة من مثالَين:

tokenizer = GemmaTokenizer(vocab)

dataset_builder = MTNTDatasetBuilder(tokenizer, max_seq_len=20)
ds = dataset_builder.get_train_dataset(3, 1)
ds = ds.take(2)
ds = ds.as_numpy_iterator()

for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'>
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'>
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'>
Example 0:
input_tokens: [[     2  49688    736   1280   6987 235292    108  10924    665  12302
  235341    108      2   4397  63011   1437  38696   1241      1      0]
 [     2  49688    736   1280   6987 235292    108  13835   1517 235265
     108      2  69875    540  19713 235265      1      0      0      0]
 [     2  49688    736   1280   6987 235292    108   6956   1586 235297
  235265    108      2  78368   1586 235297 235265      1      0      0]]
target_mask: [[False False False False False False False False False False False False
   True  True  True  True  True  True  True False]
 [False False False False False False False False False False False  True
   True  True  True  True  True False False False]
 [False False False False False False False False False False False False
   True  True  True  True  True  True False False]]

Example 1:
input_tokens: [[     2  49688    736   1280   6987 235292    108  18874 235341    108
       2 115905   6425   1241      1      0      0      0      0      0]
 [     2  49688    736   1280   6987 235292    108   7574   3356 235341
     108      2   7997  20707   1241      1      0      0      0      0]
 [     2  49688    736   1280   6987 235292    108   8703    665 235265
     108      2 235338 235303  90006  20133 235265      1      0      0]]
target_mask: [[False False False False False False False False False False  True  True
   True  True  True False False False False False]
 [False False False False False False False False False False False  True
   True  True  True  True False False False False]
 [False False False False False False False False False False False  True
   True  True  True  True  True  True False False]]

ضبط النموذج

قبل البدء في ضبط نموذج Gemma، يجب ضبطه.

أولاً، عليك تحميل نقطة مراجعة نموذج Gemma وتنسيقها باستخدام الطريقة gemma.params.load_and_format_params:

params = params_lib.load_and_format_params(CKPT_PATH)

لتحميل الإعدادات الصحيحة تلقائيًا من نقطة مراجعة نموذج Gemma، استخدِم gemma.transformer.TransformerConfig. الوسيطة cache_size هي عدد الخطوات الزمنية في ذاكرة التخزين المؤقت لـ Gemma Transformer. بعد ذلك، يمكنك إنشاء مثيل لنموذج Gemma كـ model_2b باستخدام gemma.transformer.Transformer (التي تكتسب من flax.linen.Module).

config_2b = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=30
)

model_2b = transformer_lib.Transformer(config=config_2b)

ضبط النموذج

في هذا القسم، سوف:

  • استخدِم الفئة gemma.transformer.Transformer لإنشاء دالة التمرير الأمامي ودالة الخسارة.
  • بناء متجهات قناع الموضع والانتباه للرموز المميزة
  • يمكنك إنشاء دالة لخطوات التدريب باستخدام Flax.
  • قم بإنشاء خطوة التحقق دون التمرير العكسي.
  • إنشاء حلقة التدريب.
  • ضبط نموذج Gemma

تحديد التمرير الأمامي ودالة الخسارة باستخدام الفئة gemma.transformer.Transformer. مكتسبة Gemma Transformer من flax.linen.Module، وتوفر طريقتين أساسيتين:

  • init: يؤدي هذا الخيار إلى إعداد مَعلمات النموذج.
  • apply: لتنفيذ الدالة __call__ في النموذج باستخدام مجموعة معيّنة من المَعلمات

    بما أنّك تعمل باستخدام أوزان Gemma المدرّبة مسبقًا، لن تحتاج إلى استخدام الدالة init.

def forward_and_loss_fn(params,
                        *,
                        model: transformer_lib.Transformer,
                        input_tokens: jax.Array,            # Shape [B, L]
                        input_mask: jax.Array,              # Shape [B, L]
                        positions: jax.Array,               # Shape [B, L]
                        attention_mask: jax.Array,          # [B, L, L]
                        ) -> jax.Array:
  """The forward pass and the loss function.

  Args:
    params: Model's input parameters.
    model: The Gemma transformer model to call.
    input_tokens: Input tokens sequence, shape [B, L].
    input_mask: Tokens to ignore when computing the loss, shape [B, L].
    positions: Relative position of each token, shape [B, L].
    attention_mask: Input attention mask, shape [B, L].

  Returns:
    The softmax cross-entropy loss for the next-token prediction task.
  """

  # The forward pass on the input data.
  # No attention cache is needed here.
  logits, _ = model.apply(
        params,
        input_tokens,
        positions,
        None,              # Attention cache is None.
        attention_mask,
    )

  # Exclude the last step as it does not appear in the targets.
  logits = logits[0, :-1]

  # Similarly, the first token cannot be predicted.
  target_tokens = input_tokens[0, 1:]
  target_mask = input_mask[0, 1:]

  # Convert the target labels to one-hot encoded vectors.
  one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])

  # Don't update on unwanted tokens.
  one_hot = one_hot * target_mask.astype(one_hot.dtype)[...,None]

  # Define the normalization factor.
  norm_factor = 1 / (jnp.sum(target_mask) + 1e-8)

  # Return the negative log likelihood (NLL) loss.
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) * norm_factor

تتطلّب الفئة gemma.transformer.Transformer متّجهًا attention_mask وpositions بجانب كل إدخال. يمكنك إنشاء هذه القيم من خلال إنشاء دالة مخصّصة تستخدم Transformer.build_positions_from_mask وTransformer.make_causal_attn_mask:

def get_attention_mask_and_positions(example: jax.Array,
                                     pad_id : int,
                                     )-> tuple[jax.Array, jax.Array]:
  """Builds the position and attention mask vectors from the given tokens."""
  pad_mask = example != pad_id
  current_token_position = transformer_lib.build_positions_from_mask(pad_mask)
  attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)
  return current_token_position, attention_mask

أنشئ الدالة train_step التي تؤدي إلى الانتقال العكسي وتحدّث معلَمات النموذج وفقًا لذلك، حيث يكون:

  • jax.value_and_grad مخصص لتقييم دالة الخسارة والتدرج أثناء التمرير للأمام والخلف.
  • optax.apply_updates مخصّص لتعديل المَعلمات.
def train_step(model: transformer_lib.Transformer,
               params,
               optimizer: optax.GradientTransformation,
               opt_state: optax.OptState,
               pad_id: int,
               example: TrainingInput):
  """Train step.

  Args:
    model: The Gemma transformer model.
    params: The model's input parameters.
    optimizer: The Optax optimizer to use.
    opt_state: The input optimizer's state.
    pad_id: ID of the pad token.
    example: Input batch.

  Returns:
    The training loss, the updated parameters, and the updated optimizer state.
  """

  # Build the position and attention mask vectors.
  positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)

  # The forward and backward passes.
  train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(params,
                                                             model=model,
                                                             input_tokens=example.input_tokens,
                                                             input_mask=example.target_mask,
                                                             positions=positions,
                                                             attention_mask=attention_mask)
  # Update the parameters.
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)

  return train_loss, params, opt_state

إنشاء الدالة validation_step بدون الممر الخلفي:

def validation_step(model: transformer_lib.Transformer,
                    params,
                    pad_id: int,
                    example: TrainingInput,
                    ):
  positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)
  val_loss = forward_and_loss_fn(params,
                                 model=model,
                                 input_tokens=example.input_tokens,
                                 input_mask=example.target_mask,
                                 positions=positions,
                                 attention_mask=attention_mask)
  return val_loss

حدِّد حلقة التدريب باستخدام optax.sgd لمُحسِّن SGD:

@chex.dataclass(frozen=True)
class TrainingConfig:
  learning_rate: float
  num_epochs: int
  eval_every_n: int
  batch_size: int
  max_steps: int | None = None

def train_loop(
    model: transformer_lib.Transformer,
    params,
    dataset_builder: MTNTDatasetBuilder,
    training_cfg: TrainingConfig):

  # Apply `jax.jit` on the training step, making the whole loop much more efficient.
  compiled_train_step = jax.jit(train_step, static_argnames=['model', 'optimizer'])

  # Apply `jax.jit` on the validation step.
  compiled_validation_step = jax.jit(validation_step, static_argnames=['model'])

  # To save memory, use the SGD optimizer instead of the usual Adam optimizer.
  # Note that for this specific example, SGD is more than enough.
  optimizer = optax.sgd(training_cfg.learning_rate)
  opt_state = optimizer.init(params)

  # Build the training dataset.
  train_ds = dataset_builder.get_train_dataset(batch_size=training_cfg.batch_size,
                                               num_epochs=training_cfg.num_epochs)
  train_ds = train_ds.as_numpy_iterator()

  # Build the validation dataset, with a limited number of samples for this demo.
  validation_ds = dataset_builder.get_validation_dataset(batch_size=training_cfg.batch_size)
  validation_ds = validation_ds.take(50)

  n_steps = 0
  avg_loss=0

  # A first round of the validation loss.
  n_steps_eval = 0
  eval_loss = 0
  val_iterator = validation_ds.as_numpy_iterator()
  for val_example in val_iterator:
    eval_loss += compiled_validation_step(model,
                                          params,
                                          dataset_builder._tokenizer.pad_id,
                                          val_example)
    n_steps_eval += 1
  print(f"Start, validation loss: {eval_loss/n_steps_eval}")

  for train_example in train_ds:
    train_loss, params, opt_state = compiled_train_step(model=model,
                                                        params=params,
                                                        optimizer=optimizer,
                                                        opt_state=opt_state,
                                                        pad_id=dataset_builder._tokenizer.pad_id,
                                                        example=train_example)
    n_steps += 1
    avg_loss += train_loss
    if n_steps % training_cfg.eval_every_n == 0:
      eval_loss = 0

      n_steps_eval = 0
      val_iterator = validation_ds.as_numpy_iterator()
      for val_example in val_iterator:
        eval_loss += compiled_validation_step(model,
                                              params,
                                              dataset_builder._tokenizer.pad_id,
                                              val_example)
        n_steps_eval +=1
      avg_loss /= training_cfg.eval_every_n
      eval_loss /= n_steps_eval
      print(f"STEP {n_steps} training loss: {avg_loss} - eval loss: {eval_loss}")
      avg_loss=0
    if training_cfg.max_steps is not None and n_steps > training_cfg.max_steps:
      break
  return params

ابدأ في ضبط نموذج Gemma على عدد محدود من الخطوات (SEQ_SIZE) للتأكد من ملاءمة ذلك في الذاكرة:

SEQ_SIZE = 25
tokenizer = GemmaTokenizer(vocab)
dataset_builder= MTNTDatasetBuilder(tokenizer, SEQ_SIZE)
training_cfg = TrainingConfig(learning_rate=1e-4,
                              num_epochs=1,
                              eval_every_n=20,
                              batch_size=1,
                              max_steps=100)

params = train_loop(model=model_2b,
                    params={'params': params['transformer']},
                    dataset_builder=dataset_builder,
                    training_cfg=training_cfg)
Start, validation loss: 10.647212982177734
STEP 20 training loss: 3.3015992641448975 - eval loss: 2.686880111694336
STEP 40 training loss: 5.375057220458984 - eval loss: 2.6751961708068848
STEP 60 training loss: 2.6599338054656982 - eval loss: 2.663877010345459
STEP 80 training loss: 4.822389125823975 - eval loss: 2.3333375453948975
STEP 100 training loss: 2.0131142139434814 - eval loss: 2.360811948776245

يُفترض أن ينخفض عدد الخطوات المفقودة في كل من التدريب والتحقق من الصحة مع عدد الخطوات.

إنشاء sampler باستخدام gemma.sampler.Sampler وهو يستخدم نقطة فحص نموذج Gemma وأداة إنشاء الرموز المميّزة.

sampler = sampler_lib.Sampler(
    transformer=model_2b,
    vocab=vocab,
    params=params['params'],
)

يمكنك استخدام sampler لمعرفة ما إذا كان النموذج يمكنه تنفيذ الترجمة. الوسيطة total_generation_steps في gemma.sampler.Sampler هي عدد الخطوات التي يتم تنفيذها عند إنشاء ردّ. للتأكّد من تطابُق الإدخال مع تنسيق التدريب، استخدِم البادئة Translate this into French:\n مع حرف سطر جديد في نهايته. وهذا يحدد النموذج الذي يبدأ الترجمة.

sampler(
    ["Translate this into French:\nHello, my name is Morgane.\n"],
    total_generation_steps=100,
    ).text
["C'est Bonjour, mon nom est Morgane.C'est Bonjour, mon nom est Morgane."]

مزيد من المعلومات