JAX 및 Flax를 사용하여 Gemma 미세 조정

ai.google.dev에서 보기 Google Colab에서 실행 Vertex AI에서 열기 GitHub에서 소스 보기

개요

Gemma는 Google DeepMind Gemini의 연구 및 기술을 기반으로 하는 최첨단 개방형 대규모 언어 모델 제품군입니다. 이 튜토리얼에서는 Google DeepMind의 gemma 라이브러리, JAX (고성능 수치 계산 라이브러리), Flax (JAX 기반 신경망 라이브러리), Chex (신뢰할 수 있는 라이브러리 기반 JAX 코드 작성을 위한 유틸리티 라이브러리), {Optax1 (안정적인 라이브러리 기반 JAX 코드) 데이터 세트, MTy1 (NT2) 그라데이션 데이터 세트의 텍스트 변환, MTy1 및 MT2 기울기의 텍스트 변환을 사용하여 영어-프랑스어 번역 작업을 위해 Gemma 2B Instruct 모델을 미세 조정하는 방법을 보여줍니다. 이 노트북에서 Flax를 직접 사용하지는 않지만 Flax를 사용하여 Gemma를 만들었습니다.

gemma 라이브러리는 JAX, Flax, Orbax (체크포인팅과 같은 학습 유틸리티를 위한 JAX 기반 라이브러리) 및 SentencePiece (tokenizer/detokenizer 라이브러리)로 작성되었습니다.

설정

1. Gemma용 Kaggle 액세스 권한 설정하기

이 튜토리얼을 완료하려면 먼저 Gemma 설정의 설정 안내에 따라 다음 작업을 수행해야 합니다.

  • kaggle.com에서 Gemma에 액세스하세요.
  • Gemma 모델을 실행하기에 충분한 리소스가 포함된 Colab 런타임을 선택합니다.
  • Kaggle 사용자 이름과 API 키를 생성하고 구성합니다.

Gemma 설정을 완료한 후에는 다음 섹션으로 이동하여 Colab 환경의 환경 변수를 설정합니다.

2. 환경 변수 설정하기

KAGGLE_USERNAMEKAGGLE_KEY의 환경 변수를 설정합니다. '액세스 권한을 부여하시겠습니까?'라는 메시지가 표시되면 보안 비밀 액세스 권한을 제공하는 데 동의합니다.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. gemma 라이브러리 설치

현재 이 노트북을 실행하기에는 무료 Colab 하드웨어 가속이 충분하지 않습니다. Colab 종량제 또는 Colab Pro를 사용하는 경우 수정 > 노트북 설정을 클릭하고 A100 GPU > 저장을 선택하여 하드웨어 가속을 사용 설정합니다.

다음으로 github.com/google-deepmind/gemma에서 Google DeepMind gemma 라이브러리를 설치해야 합니다. 'pip의 종속 항목 리졸버'에 관한 오류가 발생하는 경우 일반적으로 무시해도 됩니다.

pip install -q git+https://github.com/google-deepmind/gemma.git

4. 라이브러리 가져오기

이 노트북은 Flax (신경망용), 핵심 JAX, SentencePiece (토큰화용), Chex (신뢰할 수 있는 JAX 코드 작성을 위한 유틸리티 라이브러리), TensorFlow 데이터 세트를 사용합니다.

import os
import enum
import re
import string

import chex
import jax
import jax.numpy as jnp
import optax

import tensorflow as tf
import tensorflow_datasets as tfds

from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm

Gemma 모델 로드

다음 세 가지 인수를 사용하는 kagglehub.model_download로 Gemma 모델을 로드합니다.

  • handle: Kaggle의 모델 핸들
  • path: (선택사항 문자열) 로컬 경로
  • force_download: (부울 선택사항) 모델을 강제로 다시 다운로드합니다.
GEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/gemma/flax/2b-it/2/download...
100%|██████████| 3.67G/3.67G [00:26<00:00, 147MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2

모델 가중치와 tokenizer의 위치를 확인한 후 경로 변수를 설정하세요. tokenizer 디렉터리는 모델을 다운로드한 기본 디렉터리에 있고 모델 가중치는 하위 디렉터리에 있습니다. 예를 들면 다음과 같습니다.

  • tokenizer.model 파일은 /LOCAL/PATH/TO/gemma/flax/2b-it/2에 있습니다.
  • 모델 체크포인트는 /LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it에 위치합니다.
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/tokenizer.model

MTNT 데이터 세트 및 Gemma tokenizer 로드 및 준비

TensorFlow 데이터 세트에서 제공되는 MTNT (Machine Translation of Noisy Text) 데이터 세트를 사용합니다.

MTNT 데이터 세트의 영어-프랑스어 데이터 세트 부분을 다운로드하고 두 가지 예를 샘플링합니다. 데이터 세트의 각 샘플에는 2개의 항목이 포함되어 있습니다. src는 원래 영어 문장이고 dst는 해당하는 프랑스어 번역입니다.

ds = tfds.load("mtnt/en-fr", split="train")

ds = ds.take(2)
ds = ds.as_numpy_iterator()

for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()
Downloading and preparing dataset 35.08 MiB (download: 35.08 MiB, generated: 11.33 MiB, total: 46.41 MiB) to /root/tensorflow_datasets/mtnt/en-fr/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Extraction completed...: 0 file [00:00, ? file/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/35692 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incomplete6YJMND/mtnt-train.tfrecord*...:   0%|          …
Generating test examples...:   0%|          | 0/1020 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incomplete6YJMND/mtnt-test.tfrecord*...:   0%|          |…
Generating valid examples...:   0%|          | 0/811 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incomplete6YJMND/mtnt-valid.tfrecord*...:   0%|          …
Dataset mtnt downloaded and prepared to /root/tensorflow_datasets/mtnt/en-fr/1.0.0. Subsequent calls will reuse this data.
Example 0:
dst: b'Le groupe de " toutes les \xc3\xa9toiles potentielles de la conf\xc3\xa9rence de l\'Est mais qui ne s\'en sortent pas dans le groupe de l\'Ouest ".'
src: b'The group of \xe2\x80\x9ceastern conference potential all stars but not making it in the West\xe2\x80\x9d group.'

Example 1:
dst: b"Kameron est-elle un peu aigrie de son manque de temps \xc3\xa0 l'\xc3\xa9cran ?"
src: b'Is Kameron a Little Salty About Her Lack of Air Time?'

sentencepiece.SentencePieceProcessor를 사용하여 구성된 Gemma tokenizer를 로드합니다.

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True

영어에서 프랑스어로 번역 작업에 맞게 SentencePieceProcessor를 맞춤설정합니다. Gemma 모델의 영어 부분을 세밀하게 조정할 것이므로 다음과 같은 몇 가지 조정이 필요합니다.

  • 입력 접두사: 각 입력에 공통 접두사를 추가하면 번역 작업임을 알 수 있습니다. 예를 들어 Translate this into French: [INPUT_SENTENCE]와 같은 접두사가 있는 프롬프트를 사용할 수 있습니다.

  • 번역 시작 접미사: 각 프롬프트 끝에 접미사를 추가하면 Gemma 모델이 언제 번역 프로세스를 시작할지 정확히 알 수 있습니다. 새 줄을 사용하여 작업을 수행합니다.

  • 언어 모델 토큰: 젬마 모델은 각 시퀀스의 시작 부분에 '시퀀스의 시작' 토큰을 예상하므로 각 학습 예의 끝에 '시퀀스의 끝' 토큰을 추가하는 것으로 충분합니다.

    다음과 같이 SentencePieceProcessor 주위에 맞춤 래퍼를 빌드합니다.

class GemmaTokenizer:

  def __init__(self,
               spm_processor: spm.SentencePieceProcessor):
    self._spm_processor = spm_processor

  @property
  def pad_id(self) -> int:
    """Fast access to the pad ID."""
    return self._spm_processor.pad_id()

  def tokenize(self,
               example: str | bytes,
               prefix: str = '',
               suffix: str = '',
               add_eos: bool = True) -> jax.Array:
    """
    The tokenization function.

    Args:
      example: Input string to tokenize.
      prefix:  Prefix to add to the input string.
      suffix:  Suffix to add to the input string.
      add_eos: If True, add an "end of sentence" token at the end of the output
               sequence.
    Returns:
      Tokens corresponding to the input string.
    """
    int_list = [self._spm_processor.bos_id()]
    int_list.extend(self._spm_processor.EncodeAsIds(prefix + example + suffix))
    if add_eos:
      int_list.append(self._spm_processor.eos_id())

    return jnp.array(int_list, dtype=jnp.int32)

  def tokenize_tf_op(self,
                     str_tensor: tf.Tensor,
                     prefix: str = '',
                     suffix: str = '',
                     add_eos: bool = True) -> tf.Tensor:
    """A TensorFlow operator for the tokenize function."""
    encoded = tf.numpy_function(
        self.tokenize,
        [str_tensor, prefix, suffix, add_eos],
        tf.int32)
    encoded.set_shape([None])
    return encoded

  def to_string(self, tokens: jax.Array) -> str:
    """Convert an array of tokens to a string."""
    return self._spm_processor.EncodeIds(tokens.tolist())

새 커스텀 GemmaTokenizer를 인스턴스화한 다음 MTNT 데이터 세트의 작은 샘플에 적용하여 테스트해 보세요.

tokenizer = GemmaTokenizer(vocab)

def tokenize_source(tokenizer, example: tf.Tensor):
  return tokenizer.tokenize_tf_op(example,
                                  prefix='Translate this into French:\n',
                                  suffix='\n',
                                  add_eos=False)
def tokenize_destination(tokenizer, example: tf.Tensor):
  return tokenizer.tokenize_tf_op(example,
                                  add_eos=True)

ds = tfds.load("mtnt/en-fr",split="train")
ds = ds.take(2)
ds = ds.map(lambda x: {'src': tokenize_source(tokenizer, x['src']),
                       'dst': tokenize_destination(tokenizer, x['dst'])})
ds = ds.as_numpy_iterator()

for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()
Example 0:
src: [     2  49688    736   1280   6987 235292    108    651   2778    576
   1080 104745  11982   5736    832   8995    901    780   3547    665
    575    573   4589 235369   2778 235265    108]
dst: [     2   2025  29653    581    664  16298   1437  55563  41435   7840
    581    683 111452    581    533 235303   9776   4108   2459    679
    485 235303    479   6728    579   1806   2499    709  29653    581
    533 235303 101323  16054      1]

Example 1:
src: [     2  49688    736   1280   6987 235292    108   2437  87150    477
    476  11709 230461   8045   3636  40268    576   4252   4897 235336
    108]
dst: [     2 213606    477   1455 235290   3510    748   8268 191017   2809
    581   2032  69972    581  11495   1305    533 235303  65978   1654
      1]

전체 MTNT 데이터 세트의 데이터 로더를 빌드합니다.

@chex.dataclass(frozen=True)
class TrainingInput:
  # Input tokens provided to the model.
  input_tokens: jax.Array

  # A mask that determines which tokens contribute to the target loss
  # calculation.
  target_mask: jax.Array

class DatasetSplit(enum.Enum):
  TRAIN = 'train'
  VALIDATION = 'valid'

class MTNTDatasetBuilder:
  """The dataset builder for the MTNT dataset."""

  N_ITEMS = {DatasetSplit.TRAIN: 35_692,
             DatasetSplit.VALIDATION: 811}

  BUFFER_SIZE_SHUFFLE = 10_000
  TRANSLATION_PREFIX = 'Translate this into French:\n'
  TRANSLATION_SUFFIX = '\n'

  def __init__(self,
               tokenizer : GemmaTokenizer,
               max_seq_len: int):
    """Constructor.

    Args:
      tokenizer: Gemma tokenizer to use.
      max_seq_len: size of each sequence in a given batch.
    """
    self._tokenizer = tokenizer
    self._base_data = {
        DatasetSplit.TRAIN: tfds.load("mtnt/en-fr",split="train"),
        DatasetSplit.VALIDATION: tfds.load("mtnt/en-fr",split="valid"),
    }
    self._max_seq_len = max_seq_len

  def _tokenize_source(self, example: tf.Tensor):
    """Tokenization function for the source."""
    return self._tokenizer.tokenize_tf_op(example,
                                          prefix=self.TRANSLATION_PREFIX,
                                          suffix=self.TRANSLATION_SUFFIX,
                                          add_eos=False)

  def _tokenize_destination(self, example: tf.Tensor):
    """Tokenization function for the French translation."""
    return self._tokenizer.tokenize_tf_op(example,
                                          add_eos=True)

  def _pad_up_to_max_len(self,
                         input_tensor: tf.Tensor,
                         pad_value: int | bool,
                         ) -> tf.Tensor:
    """Pad the given tensor up to sequence length of a batch."""
    seq_len = tf.shape(input_tensor)[0]
    to_pad = tf.maximum(self._max_seq_len - seq_len, 0)
    return tf.pad(input_tensor,
                  [[0, to_pad]],
                  mode='CONSTANT',
                  constant_values=pad_value,
                  )

  def _to_training_input(self,
                         src_tokens: jax.Array,
                         dst_tokens: jax.Array,
                         ) -> TrainingInput:
    """Build a training input from a tuple of source and destination tokens."""

    # The input sequence fed to the model is simply the concatenation of the
    # source and the destination.
    tokens = tf.concat([src_tokens, dst_tokens], axis=0)

    # To prevent the model from updating based on the source (input)
    # tokens, add a target mask to each input.
    q_mask = tf.zeros_like(src_tokens, dtype=tf.bool)
    a_mask = tf.ones_like(dst_tokens, dtype=tf.bool)
    mask = tf.concat([q_mask, a_mask], axis=0)

    # If the output tokens sequence is smaller than the target sequence size,
    # then pad it with pad tokens.
    tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)

    # Don't want to perform the backward pass on the pad tokens.
    mask = self._pad_up_to_max_len(mask, False)

    return TrainingInput(input_tokens=tokens, target_mask=mask)


  def get_train_dataset(self, batch_size: int, num_epochs: int):
    """Build the training dataset."""

    # Tokenize each sample.
    ds = self._base_data[DatasetSplit.TRAIN].map(lambda x : (self._tokenize_source(x['src']),
                                                             self._tokenize_destination(x['dst'])))

    # Convert the samples to training inputs.
    ds = ds.map(lambda x, y: self._to_training_input(x, y))

    # Remove the samples that are too long.
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)

    # Shuffle the dataset.
    ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)

    # Repeat if necessary.
    ds = ds.repeat(num_epochs)

    # Build batches.
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

  def get_validation_dataset(self, batch_size: int):
    """Build the validation dataset."""

    # Same steps as in `get_train_dataset`, but without shuffling and no repetition.
    ds = self._base_data[DatasetSplit.VALIDATION].map(lambda x : (self._tokenize_source(x['src']),
                                                                  self._tokenize_destination(x['dst'])))
    ds = ds.map(lambda x, y: self._to_training_input(x, y))
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

커스텀 GemmaTokenizer를 다시 인스턴스화한 다음 MTNT 데이터 세트에 적용하고 다음 두 가지 예를 샘플링하여 MTNTDatasetBuilder를 사용해 보세요.

tokenizer = GemmaTokenizer(vocab)

dataset_builder = MTNTDatasetBuilder(tokenizer, max_seq_len=20)
ds = dataset_builder.get_train_dataset(3, 1)
ds = ds.take(2)
ds = ds.as_numpy_iterator()

for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'>
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'>
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'>
Example 0:
input_tokens: [[     2  49688    736   1280   6987 235292    108  10924    665  12302
  235341    108      2   4397  63011   1437  38696   1241      1      0]
 [     2  49688    736   1280   6987 235292    108  13835   1517 235265
     108      2  69875    540  19713 235265      1      0      0      0]
 [     2  49688    736   1280   6987 235292    108   6956   1586 235297
  235265    108      2  78368   1586 235297 235265      1      0      0]]
target_mask: [[False False False False False False False False False False False False
   True  True  True  True  True  True  True False]
 [False False False False False False False False False False False  True
   True  True  True  True  True False False False]
 [False False False False False False False False False False False False
   True  True  True  True  True  True False False]]

Example 1:
input_tokens: [[     2  49688    736   1280   6987 235292    108  18874 235341    108
       2 115905   6425   1241      1      0      0      0      0      0]
 [     2  49688    736   1280   6987 235292    108   7574   3356 235341
     108      2   7997  20707   1241      1      0      0      0      0]
 [     2  49688    736   1280   6987 235292    108   8703    665 235265
     108      2 235338 235303  90006  20133 235265      1      0      0]]
target_mask: [[False False False False False False False False False False  True  True
   True  True  True False False False False False]
 [False False False False False False False False False False False  True
   True  True  True  True False False False False]
 [False False False False False False False False False False False  True
   True  True  True  True  True  True False False]]

모델 구성

Gemma 모델을 미세 조정하기 전에 먼저 구성해야 합니다.

먼저 gemma.params.load_and_format_params 메서드로 Gemma 모델 체크포인트를 로드하고 형식을 지정합니다.

params = params_lib.load_and_format_params(CKPT_PATH)

Gemma 모델 체크포인트에서 올바른 구성을 자동으로 로드하려면 gemma.transformer.TransformerConfig를 사용하세요. cache_size 인수는 Gemma Transformer 캐시의 시간 단계 수입니다. 그런 다음 flax.linen.Module에서 상속되는 gemma.transformer.Transformer를 사용하여 Gemma 모델을 model_2b로 인스턴스화합니다.

config_2b = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=30
)

model_2b = transformer_lib.Transformer(config=config_2b)

모델 미세 조정

이 섹션에서 수행할 작업은 다음과 같습니다.

  • gemma.transformer.Transformer 클래스를 사용하여 정방향 전달 및 손실 함수를 만듭니다.
  • 토큰의 위치 및 어텐션 마스크 벡터 빌드
  • Flax를 사용하여 학습 단계 함수를 빌드합니다.
  • 역방향 통과 없이 검증 단계를 빌드합니다.
  • 학습 루프를 만듭니다.
  • Gemma 모델 미세 조정

gemma.transformer.Transformer 클래스를 사용하여 정방향 패스와 손실 함수를 정의합니다. Gemma Transformerflax.linen.Module에서 상속되며 두 가지 필수 메서드를 제공합니다.

  • init: 모델의 매개변수를 초기화합니다.
  • apply: 지정된 매개변수 집합을 사용하여 모델의 __call__ 함수를 실행합니다.

    선행 학습된 Gemma 가중치로 작업하고 있으므로 init 함수를 사용할 필요가 없습니다.

def forward_and_loss_fn(params,
                        *,
                        model: transformer_lib.Transformer,
                        input_tokens: jax.Array,            # Shape [B, L]
                        input_mask: jax.Array,              # Shape [B, L]
                        positions: jax.Array,               # Shape [B, L]
                        attention_mask: jax.Array,          # [B, L, L]
                        ) -> jax.Array:
  """The forward pass and the loss function.

  Args:
    params: Model's input parameters.
    model: The Gemma transformer model to call.
    input_tokens: Input tokens sequence, shape [B, L].
    input_mask: Tokens to ignore when computing the loss, shape [B, L].
    positions: Relative position of each token, shape [B, L].
    attention_mask: Input attention mask, shape [B, L].

  Returns:
    The softmax cross-entropy loss for the next-token prediction task.
  """

  # The forward pass on the input data.
  # No attention cache is needed here.
  logits, _ = model.apply(
        params,
        input_tokens,
        positions,
        None,              # Attention cache is None.
        attention_mask,
    )

  # Exclude the last step as it does not appear in the targets.
  logits = logits[0, :-1]

  # Similarly, the first token cannot be predicted.
  target_tokens = input_tokens[0, 1:]
  target_mask = input_mask[0, 1:]

  # Convert the target labels to one-hot encoded vectors.
  one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])

  # Don't update on unwanted tokens.
  one_hot = one_hot * target_mask.astype(one_hot.dtype)[...,None]

  # Define the normalization factor.
  norm_factor = 1 / (jnp.sum(target_mask) + 1e-8)

  # Return the negative log likelihood (NLL) loss.
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) * norm_factor

gemma.transformer.Transformer 클래스에는 각 입력과 함께 attention_maskpositions 벡터가 있어야 합니다. Transformer.build_positions_from_maskTransformer.make_causal_attn_mask를 사용하는 맞춤 함수를 만들어 생성할 수 있습니다.

def get_attention_mask_and_positions(example: jax.Array,
                                     pad_id : int,
                                     )-> tuple[jax.Array, jax.Array]:
  """Builds the position and attention mask vectors from the given tokens."""
  pad_mask = example != pad_id
  current_token_position = transformer_lib.build_positions_from_mask(pad_mask)
  attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)
  return current_token_position, attention_mask

역방향 전달을 실행하고 그에 따라 모델의 매개변수를 업데이트하는 train_step 함수를 빌드합니다. 각 항목의 의미는 다음과 같습니다.

  • jax.value_and_grad는 정방향 및 역방향 전달 중에 손실 함수와 경사를 평가하는 데 사용됩니다.
  • optax.apply_updates는 매개변수를 업데이트하는 데 사용됩니다.
def train_step(model: transformer_lib.Transformer,
               params,
               optimizer: optax.GradientTransformation,
               opt_state: optax.OptState,
               pad_id: int,
               example: TrainingInput):
  """Train step.

  Args:
    model: The Gemma transformer model.
    params: The model's input parameters.
    optimizer: The Optax optimizer to use.
    opt_state: The input optimizer's state.
    pad_id: ID of the pad token.
    example: Input batch.

  Returns:
    The training loss, the updated parameters, and the updated optimizer state.
  """

  # Build the position and attention mask vectors.
  positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)

  # The forward and backward passes.
  train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(params,
                                                             model=model,
                                                             input_tokens=example.input_tokens,
                                                             input_mask=example.target_mask,
                                                             positions=positions,
                                                             attention_mask=attention_mask)
  # Update the parameters.
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)

  return train_loss, params, opt_state

역방향 전달 없이 validation_step 함수를 빌드합니다.

def validation_step(model: transformer_lib.Transformer,
                    params,
                    pad_id: int,
                    example: TrainingInput,
                    ):
  positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)
  val_loss = forward_and_loss_fn(params,
                                 model=model,
                                 input_tokens=example.input_tokens,
                                 input_mask=example.target_mask,
                                 positions=positions,
                                 attention_mask=attention_mask)
  return val_loss

SGD 옵티마이저에 optax.sgd를 사용하여 학습 루프를 정의합니다.

@chex.dataclass(frozen=True)
class TrainingConfig:
  learning_rate: float
  num_epochs: int
  eval_every_n: int
  batch_size: int
  max_steps: int | None = None

def train_loop(
    model: transformer_lib.Transformer,
    params,
    dataset_builder: MTNTDatasetBuilder,
    training_cfg: TrainingConfig):

  # Apply `jax.jit` on the training step, making the whole loop much more efficient.
  compiled_train_step = jax.jit(train_step, static_argnames=['model', 'optimizer'])

  # Apply `jax.jit` on the validation step.
  compiled_validation_step = jax.jit(validation_step, static_argnames=['model'])

  # To save memory, use the SGD optimizer instead of the usual Adam optimizer.
  # Note that for this specific example, SGD is more than enough.
  optimizer = optax.sgd(training_cfg.learning_rate)
  opt_state = optimizer.init(params)

  # Build the training dataset.
  train_ds = dataset_builder.get_train_dataset(batch_size=training_cfg.batch_size,
                                               num_epochs=training_cfg.num_epochs)
  train_ds = train_ds.as_numpy_iterator()

  # Build the validation dataset, with a limited number of samples for this demo.
  validation_ds = dataset_builder.get_validation_dataset(batch_size=training_cfg.batch_size)
  validation_ds = validation_ds.take(50)

  n_steps = 0
  avg_loss=0

  # A first round of the validation loss.
  n_steps_eval = 0
  eval_loss = 0
  val_iterator = validation_ds.as_numpy_iterator()
  for val_example in val_iterator:
    eval_loss += compiled_validation_step(model,
                                          params,
                                          dataset_builder._tokenizer.pad_id,
                                          val_example)
    n_steps_eval += 1
  print(f"Start, validation loss: {eval_loss/n_steps_eval}")

  for train_example in train_ds:
    train_loss, params, opt_state = compiled_train_step(model=model,
                                                        params=params,
                                                        optimizer=optimizer,
                                                        opt_state=opt_state,
                                                        pad_id=dataset_builder._tokenizer.pad_id,
                                                        example=train_example)
    n_steps += 1
    avg_loss += train_loss
    if n_steps % training_cfg.eval_every_n == 0:
      eval_loss = 0

      n_steps_eval = 0
      val_iterator = validation_ds.as_numpy_iterator()
      for val_example in val_iterator:
        eval_loss += compiled_validation_step(model,
                                              params,
                                              dataset_builder._tokenizer.pad_id,
                                              val_example)
        n_steps_eval +=1
      avg_loss /= training_cfg.eval_every_n
      eval_loss /= n_steps_eval
      print(f"STEP {n_steps} training loss: {avg_loss} - eval loss: {eval_loss}")
      avg_loss=0
    if training_cfg.max_steps is not None and n_steps > training_cfg.max_steps:
      break
  return params

제한된 수의 단계 (SEQ_SIZE)로 Gemma 모델을 미세 조정하여 메모리에 맞도록 합니다.

SEQ_SIZE = 25
tokenizer = GemmaTokenizer(vocab)
dataset_builder= MTNTDatasetBuilder(tokenizer, SEQ_SIZE)
training_cfg = TrainingConfig(learning_rate=1e-4,
                              num_epochs=1,
                              eval_every_n=20,
                              batch_size=1,
                              max_steps=100)

params = train_loop(model=model_2b,
                    params={'params': params['transformer']},
                    dataset_builder=dataset_builder,
                    training_cfg=training_cfg)
Start, validation loss: 10.647212982177734
STEP 20 training loss: 3.3015992641448975 - eval loss: 2.686880111694336
STEP 40 training loss: 5.375057220458984 - eval loss: 2.6751961708068848
STEP 60 training loss: 2.6599338054656982 - eval loss: 2.663877010345459
STEP 80 training loss: 4.822389125823975 - eval loss: 2.3333375453948975
STEP 100 training loss: 2.0131142139434814 - eval loss: 2.360811948776245

걸음 수마다 학습 손실과 검증 손실이 모두 감소해야 합니다.

gemma.sampler.Samplersampler를 만듭니다. Gemma 모델 체크포인트와 tokenizer를 사용합니다.

sampler = sampler_lib.Sampler(
    transformer=model_2b,
    vocab=vocab,
    params=params['params'],
)

sampler를 사용하여 모델이 번역을 수행할 수 있는지 확인합니다. gemma.sampler.Samplertotal_generation_steps 인수는 응답을 생성할 때 수행된 단계 수입니다. 입력이 학습 형식과 일치하도록 하려면 Translate this into French:\n 프리픽스를 줄바꿈 문자와 함께 사용합니다. 그러면 모델이 번역을 시작하라는 신호를 보냅니다.

sampler(
    ["Translate this into French:\nHello, my name is Morgane.\n"],
    total_generation_steps=100,
    ).text
["C'est Bonjour, mon nom est Morgane.C'est Bonjour, mon nom est Morgane."]

자세히 알아보기