ai.google.dev에서 보기 | Google Colab에서 실행 | Vertex AI에서 열기 | GitHub에서 소스 보기 |
이 튜토리얼에서는 Google DeepMind의 recurrentgemma
라이브러리, JAX(고성능 수치 컴퓨팅 라이브러리), Flax(JAX 기반 신경망 라이브러리), Chex(안정적인 JAX 코드 작성을 위한 유틸리티 라이브러리인 Chex(안정적인 JAX 코드 작성을 위한 유틸리티 라이브러리), JAX-MT1 라이브러리로 신뢰할 수 있는 JAX 코드를 작성하는 라이브러리인 JAX 및 NT23){Machine-to-Text 라이브러리RecurrentGemmaOptax 이 노트북에서는 Flax를 직접 사용하지 않지만 Gemma를 만드는 데 Flax를 사용했습니다.
recurrentgemma
라이브러리는 JAX, Flax, Orbax (체크포인트와 같은 학습 유틸리티용 JAX 기반 라이브러리) 및 SentencePiece (tokenizer/detokenizer 라이브러리)로 작성되었습니다.
이 노트북은 T4 GPU를 사용하는 Google Colab에서 실행할 수 있습니다 (수정 > 노트북 설정으로 이동한 후 하드웨어 가속기에서 T4 GPU 선택).
설정
다음 섹션에서는 모델 액세스, API 키 가져오기, 노트북 런타임 구성 등 RecurrentGemma 모델을 사용하기 위해 노트북을 준비하는 단계를 설명합니다.
Gemma에 Kaggle 액세스 권한 설정하기
이 튜토리얼을 완료하려면 먼저 Gemma 설정과 비슷한 설정 안내를 따라야 하지만 몇 가지 예외가 있습니다.
- kaggle.com에서 Gemma 대신 RecurrentGemma에 액세스하세요.
- RecurrentGemma 모델을 실행하기에 충분한 리소스가 있는 Colab 런타임을 선택하세요.
- Kaggle 사용자 이름 및 API 키를 생성하고 구성합니다.
RecurrentGemma 설정을 완료한 후 다음 섹션으로 이동하여 Colab 환경의 환경 변수를 설정합니다.
환경 변수 설정하기
KAGGLE_USERNAME
및 KAGGLE_KEY
의 환경 변수를 설정합니다. '액세스 권한을 부여하시겠습니까?'라는 메시지가 표시되면 비밀 액세스 제공에 동의해야 합니다.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
recurrentgemma
라이브러리 설치
현재 무료 Colab 하드웨어 가속으로는 이 노트북을 실행할 수 없습니다. Colab 종량제 또는 Colab Pro를 사용하는 경우 수정을 클릭합니다. 노트북 설정 > A100 GPU를 선택합니다. 저장하여 하드웨어 가속을 사용 설정합니다.
다음으로 github.com/google-deepmind/recurrentgemma
에서 Google DeepMind recurrentgemma
라이브러리를 설치해야 합니다. 'pip의 종속 항목 리졸버'에 관한 오류가 발생하면 일반적으로 무시해도 됩니다.
pip install -q git+https://github.com/google-deepmind/recurrentgemma.git
라이브러리 가져오기
이 노트북은 Flax (신경망용), 코어 JAX, SentencePiece (토큰화용), Chex (안정적인 JAX 코드를 작성하기 위한 유틸리티 라이브러리), Optax (그라데이션 처리 및 최적화 라이브러리), TensorFlow 데이터 세트를 사용합니다.
import pathlib
from typing import Any, Mapping, Iterator
import enum
import functools
import chex
import jax
import jax.numpy as jnp
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
import sentencepiece as spm
from recurrentgemma import jax as recurrentgemma
RecurrentGemma 모델 로드
- 세 가지 인수를 사용하는
kagglehub.model_download
를 사용하여 RecurrentGemma 모델을 로드합니다.
handle
: Kaggle의 모델 핸들path
: (선택사항 문자열) 로컬 경로force_download
: (선택적 불리언) 모델을 강제로 다시 다운로드합니다.
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub
RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
<ph type="x-smartling-placeholder">Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download... 100%|██████████| 3.85G/3.85G [00:50<00:00, 81.5MB/s] Extracting model files...</ph>
print('RECURRENTGEMMA_VARIANT:', RECURRENTGEMMA_VARIANT)
RECURRENTGEMMA_VARIANT: 2b-it
- 모델 가중치와 tokenizer의 위치를 확인한 다음 경로 변수를 설정합니다. tokenizer 디렉터리는 모델을 다운로드한 기본 디렉터리에 있고 모델 가중치는 하위 디렉터리에 있습니다. 예를 들면 다음과 같습니다.
tokenizer.model
파일은/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1
에 있습니다.- 모델 체크포인트는
/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it
에 있습니다.
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model
MTNT 데이터 세트 및 Gemma tokenizer 로드 및 준비
TensorFlow 데이터 세트에서 제공되는 MTNT (Machine Translation of Noisy Text) 데이터 세트를 사용합니다.
MTNT 데이터 세트의 영어-프랑스어 데이터 세트 부분을 다운로드한 다음 두 가지 예시를 샘플링합니다. 데이터 세트의 각 샘플에는 src
라는 두 가지 항목이 있습니다. 하나는 영어 문장이고 및 dst
: 상응하는 프랑스어 번역입니다.
ds = tfds.load("mtnt/en-fr", split="train")
ds = ds.take(2)
ds = ds.as_numpy_iterator()
for idx, example in enumerate(ds):
print(f'Example {idx}:')
for key, val in example.items():
print(f'{key}: {val}')
print()
Downloading and preparing dataset 35.08 MiB (download: 35.08 MiB, generated: 11.33 MiB, total: 46.41 MiB) to /root/tensorflow_datasets/mtnt/en-fr/1.0.0... Dl Completed...: 0 url [00:00, ? url/s] Dl Size...: 0 MiB [00:00, ? MiB/s] Extraction completed...: 0 file [00:00, ? file/s] Generating splits...: 0%| | 0/3 [00:00<?, ? splits/s] Generating train examples...: 0%| | 0/35692 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incompleteJLH33K/mtnt-train.tfrecord*...: 0%| … Generating test examples...: 0%| | 0/1020 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incompleteJLH33K/mtnt-test.tfrecord*...: 0%| |… Generating valid examples...: 0%| | 0/811 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incompleteJLH33K/mtnt-valid.tfrecord*...: 0%| … Dataset mtnt downloaded and prepared to /root/tensorflow_datasets/mtnt/en-fr/1.0.0. Subsequent calls will reuse this data. Example 0: dst: b'Le groupe de " toutes les \xc3\xa9toiles potentielles de la conf\xc3\xa9rence de l\'Est mais qui ne s\'en sortent pas dans le groupe de l\'Ouest ".' src: b'The group of \xe2\x80\x9ceastern conference potential all stars but not making it in the West\xe2\x80\x9d group.' Example 1: dst: b"Kameron est-elle un peu aigrie de son manque de temps \xc3\xa0 l'\xc3\xa9cran ?" src: b'Is Kameron a Little Salty About Her Lack of Air Time?'
sentencepiece.SentencePieceProcessor
를 사용하여 구성된 Gemma tokenizer를 로드합니다.
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
영어에서 프랑스어로 번역하는 작업의 SentencePieceProcessor
을 맞춤설정합니다. RecurrentGemma (Griffin) 모델의 영어 부분을 미세 조정하게 되므로 다음과 같은 몇 가지 조정이 필요합니다.
입력 접두사: 각 입력에 공통 접두사를 추가하면 번역 작업을 알립니다. 예를 들어
Translate this into French: [INPUT_SENTENCE]
와 같은 접두사가 포함된 프롬프트를 사용할 수 있습니다.번역 시작 접미사: 각 프롬프트의 끝에 접미사를 추가하면 Gemma 모델에 정확히 언제 번역 프로세스를 시작할지 지시하게 됩니다. 새 줄로 작업을 실행할 수 있습니다.
언어 모델 토큰: RecurrentGemma (Griffin) 모델은 '시퀀스의 시작'을 예상합니다. 각 시퀀스의 시작 부분에 있습니다. 마찬가지로 '시퀀스의 끝'을 추가해야 각 학습 예의 끝부분에 있습니다.
다음과 같이 SentencePieceProcessor
주위에 맞춤 래퍼를 빌드합니다.
class GriffinTokenizer:
"""A custom wrapper around a SentencePieceProcessor."""
def __init__(self, spm_processor: spm.SentencePieceProcessor):
self._spm_processor = spm_processor
@property
def pad_id(self) -> int:
"""Fast access to the pad ID."""
return self._spm_processor.pad_id()
def tokenize(
self,
example: str | bytes,
prefix: str = '',
suffix: str = '',
add_eos: bool = True,
) -> jax.Array:
"""
A tokenization function.
Args:
example: Input string to tokenize.
prefix: Prefix to add to the input string.
suffix: Suffix to add to the input string.
add_eos: If True, add an end of sentence token at the end of the output
sequence.
Returns:
Tokens corresponding to the input string.
"""
int_list = [self._spm_processor.bos_id()]
int_list.extend(self._spm_processor.EncodeAsIds(prefix + example + suffix))
if add_eos:
int_list.append(self._spm_processor.eos_id())
return jnp.array(int_list, dtype=jnp.int32)
def tokenize_tf_op(
self,
str_tensor: tf.Tensor,
prefix: str = '',
suffix: str = '',
add_eos: bool = True,
) -> tf.Tensor:
"""A TensforFlow operator for the `tokenize` function."""
encoded = tf.numpy_function(
self.tokenize,
[str_tensor, prefix, suffix, add_eos],
tf.int32)
encoded.set_shape([None])
return encoded
def to_string(self, tokens: jax.Array) -> str:
"""Convert an array of tokens to a string."""
return self._spm_processor.EncodeIds(tokens.tolist())
새 커스텀 GriffinTokenizer
를 인스턴스화한 후 MTNT 데이터 세트의 소규모 샘플에 적용해 보세요.
def tokenize_source(tokenizer, example: tf.Tensor):
return tokenizer.tokenize_tf_op(
example,
prefix='Translate this into French:\n',
suffix='\n',
add_eos=False
)
def tokenize_destination(tokenizer, example: tf.Tensor):
return tokenizer.tokenize_tf_op(example, add_eos=True)
tokenizer = GriffinTokenizer(vocab)
ds = tfds.load("mtnt/en-fr",split="train")
ds = ds.take(2)
ds = ds.map(lambda x: {
'src': tokenize_source(tokenizer, x['src']),
'dst': tokenize_destination(tokenizer, x['dst'])
})
ds = ds.as_numpy_iterator()
for idx, example in enumerate(ds):
print(f'Example {idx}:')
for key, val in example.items():
print(f'{key}: {val}')
print()
Example 0: src: [ 2 49688 736 1280 6987 235292 108 651 2778 576 1080 104745 11982 5736 832 8995 901 780 3547 665 575 573 4589 235369 2778 235265 108] dst: [ 2 2025 29653 581 664 16298 1437 55563 41435 7840 581 683 111452 581 533 235303 9776 4108 2459 679 485 235303 479 6728 579 1806 2499 709 29653 581 533 235303 101323 16054 1] Example 1: src: [ 2 49688 736 1280 6987 235292 108 2437 87150 477 476 11709 230461 8045 3636 40268 576 4252 4897 235336 108] dst: [ 2 213606 477 1455 235290 3510 748 8268 191017 2809 581 2032 69972 581 11495 1305 533 235303 65978 1654 1]
전체 MTNT 데이터 세트를 위한 데이터 로더를 빌드합니다.
@chex.dataclass(frozen=True)
class TrainingInput:
# Input tokens provided to the model.
input_tokens: jax.Array
# A mask that determines which tokens contribute to the target loss
# calculation.
target_mask: jax.Array
class DatasetSplit(enum.Enum):
TRAIN = 'train'
VALIDATION = 'valid'
class MTNTDatasetBuilder:
"""A data loader for the MTNT dataset."""
N_ITEMS = {DatasetSplit.TRAIN: 35_692, DatasetSplit.VALIDATION: 811}
BUFFER_SIZE_SHUFFLE = 10_000
TRANSLATION_PREFIX = 'Translate this into French:\n'
TRANSLATION_SUFFIX = '\n'
def __init__(self,
tokenizer : GriffinTokenizer,
max_seq_len: int):
"""A constructor.
Args:
tokenizer: The tokenizer to use.
max_seq_len: The size of each sequence in a given batch.
"""
self._tokenizer = tokenizer
self._base_data = {
DatasetSplit.TRAIN: tfds.load("mtnt/en-fr",split="train"),
DatasetSplit.VALIDATION: tfds.load("mtnt/en-fr",split="valid"),
}
self._max_seq_len = max_seq_len
def _tokenize_source(self, example: tf.Tensor):
"""A tokenization function for the source."""
return self._tokenizer.tokenize_tf_op(
example, prefix=self.TRANSLATION_PREFIX, suffix=self.TRANSLATION_SUFFIX,
add_eos=False
)
def _tokenize_destination(self, example: tf.Tensor):
"""A tokenization function for the French translation."""
return self._tokenizer.tokenize_tf_op(example, add_eos=True)
def _pad_up_to_max_len(self,
input_tensor: tf.Tensor,
pad_value: int | bool,
) -> tf.Tensor:
"""Pad the given tensor up to sequence length of a batch."""
seq_len = tf.shape(input_tensor)[0]
to_pad = tf.maximum(self._max_seq_len - seq_len, 0)
return tf.pad(
input_tensor, [[0, to_pad]], mode='CONSTANT', constant_values=pad_value,
)
def _to_training_input(
self,
src_tokens: jax.Array,
dst_tokens: jax.Array,
) -> TrainingInput:
"""Build a training input from a tuple of source and destination tokens."""
# The input sequence fed to the model is simply the concatenation of the
# source and the destination.
tokens = tf.concat([src_tokens, dst_tokens], axis=0)
# You want to prevent the model from updating based on the source (input)
# tokens. To achieve this, add a target mask to each input.
q_mask = tf.zeros_like(src_tokens, dtype=tf.bool)
a_mask = tf.ones_like(dst_tokens, dtype=tf.bool)
mask = tf.concat([q_mask, a_mask], axis=0)
# If the output tokens sequence is smaller than the target sequence size,
# then pad it with pad tokens.
tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)
# You don't want to perform the backward on the pad tokens.
mask = self._pad_up_to_max_len(mask, False)
return TrainingInput(input_tokens=tokens, target_mask=mask)
def get_train_dataset(self, batch_size: int, num_epochs: int):
"""Build the training dataset."""
# Tokenize each sample.
ds = self._base_data[DatasetSplit.TRAIN].map(
lambda x : (self._tokenize_source(x['src']),
self._tokenize_destination(x['dst']))
)
# Convert them to training inputs.
ds = ds.map(lambda x, y: self._to_training_input(x, y))
# Remove the samples which are too long.
ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
# Shuffle the dataset.
ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)
# Repeat if necessary.
ds = ds.repeat(num_epochs)
# Build batches.
ds = ds.batch(batch_size, drop_remainder=True)
return ds
def get_validation_dataset(self, batch_size: int):
"""Build the validation dataset."""
# Same as the training dataset, but no shuffling and no repetition
ds = self._base_data[DatasetSplit.VALIDATION].map(
lambda x : (self._tokenize_source(x['src']),
self._tokenize_destination(x['dst']))
)
ds = ds.map(lambda x, y: self._to_training_input(x, y))
ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
ds = ds.batch(batch_size, drop_remainder=True)
return ds
커스텀 GriffinTokenizer
를 다시 인스턴스화한 후 MTNT 데이터 세트에 적용하고 두 가지 예를 샘플링하여 MTNTDatasetBuilder
를 사용해 봅니다.
dataset_builder = MTNTDatasetBuilder(tokenizer, max_seq_len=20)
ds = dataset_builder.get_train_dataset(3, 1)
ds = ds.take(2)
ds = ds.as_numpy_iterator()
for idx, example in enumerate(ds):
print(f'Example {idx}:')
for key, val in example.items():
print(f'{key}: {val}')
print()
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'> WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'> WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'> Example 0: input_tokens: [[ 2 49688 736 1280 6987 235292 108 12583 665 235265 108 2 6151 94975 1320 6238 235265 1 0 0] [ 2 49688 736 1280 6987 235292 108 4899 29960 11270 108282 235265 108 2 4899 79025 11270 108282 1 0] [ 2 49688 736 1280 6987 235292 108 26620 235265 108 2 26620 235265 1 0 0 0 0 0 0]] target_mask: [[False False False False False False False False False False False True True True True True True True False False] [False False False False False False False False False False False False False True True True True True True False] [False False False False False False False False False False True True True True False False False False False False]] Example 1: input_tokens: [[ 2 49688 736 1280 6987 235292 108 527 5174 1683 235336 108 2 206790 581 20726 482 2208 1654 1] [ 2 49688 736 1280 6987 235292 108 28484 235256 235336 108 2 120500 13832 1654 1 0 0 0 0] [ 2 49688 736 1280 6987 235292 108 235324 235304 2705 235265 108 2 235324 235304 19963 235265 1 0 0]] target_mask: [[False False False False False False False False False False False False True True True True True True True True] [False False False False False False False False False False False True True True True True False False False False] [False False False False False False False False False False False False True True True True True True False False]]
모델 구성
Gemma 모델의 미세 조정을 시작하기 전에 먼저 구성해야 합니다.
recurrentgemma.jax.utils.load_parameters
메서드를 사용하여 RecurrentGemma (Griffin) 모델 체크포인트를 로드합니다.
params = recurrentgemma.load_parameters(CKPT_PATH, "single_device")
RecurrentGemma 모델 체크포인트에서 올바른 구성을 자동으로 로드하려면 recurrentgemma.GriffinConfig.from_flax_params_or_variables
를 사용합니다.
config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(params)
recurrentgemma.jax.Griffin
를 사용하여 Griffin 모델을 인스턴스화합니다.
model = recurrentgemma.Griffin(config)
RecurrentGemma 모델 체크포인트/가중치 및 tokenizer 위에 recurrentgemma.jax.Sampler
가 있는 sampler
를 만들어 모델이 변환을 수행할 수 있는지 확인합니다.
sampler = recurrentgemma.Sampler(model=model, vocab=vocab, params=params)
모델 미세 조정
이 섹션에서 수행할 작업은 다음과 같습니다.
gemma.transformer.Transformer
클래스를 사용하여 정방향 전달 및 손실 함수를 만듭니다.- 토큰의 위치 및 어텐션 마스크 벡터 빌드
- Flax를 사용하여 학습 단계 함수를 빌드합니다.
- 역방향 전달 없이 유효성 검사 단계를 빌드합니다.
- 학습 루프를 만듭니다.
- Gemma 모델을 미세 조정합니다.
recurrentgemma.jax.griffin.Griffin
를 사용하여 정방향 전달과 손실 함수 정의
클래스에 대해 자세히 알아보세요. RecurrentGemma Griffin
는 flax.linen.Module
에서 상속되며 다음과 같은 두 가지 필수 메서드를 제공합니다.
init
: 모델의 매개변수를 초기화합니다.apply
: 지정된 매개변수 집합을 사용하여 모델의__call__
함수를 실행합니다.
선행 학습된 Gemma 가중치로 작업하고 있으므로 init
함수를 사용할 필요가 없습니다.
def forward_and_loss_fn(
params,
*,
model: recurrentgemma.Griffin,
input_tokens: jax.Array, # Shape [B, L]
input_mask: jax.Array, # Shape [B, L]
positions: jax.Array, # Shape [B, L]
) -> jax.Array:
"""Forward pass and loss function.
Args:
params: model's input parameters.
model: Griffin model to call.
input_tokens: input tokens sequence, shape [B, L].
input_mask: tokens to ignore when computing the loss, shape [B, L].
positions: relative position of each token, shape [B, L].
Returns:
Softmax cross-entropy loss for the next-token prediction task.
"""
batch_size = input_tokens.shape[0]
# Forward pass on the input data.
# No attention cache is needed here.
# Exclude the last step as it does not appear in the targets.
logits, _ = model.apply(
{"params": params},
tokens=input_tokens[:, :-1],
segment_pos=positions[:, :-1],
cache=None,
)
# Similarly, the first token cannot be predicteds.
target_tokens = input_tokens[:, 1:]
target_mask = input_mask[:, 1:]
# Convert the target labels into one-hot encoded vectors.
one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])
# Don't update on unwanted tokens.
one_hot = one_hot * target_mask.astype(one_hot.dtype)[...,None]
# Normalization factor.
norm_factor = batch_size * (jnp.sum(target_mask) + 1e-8)
# Return the negative log-likelihood loss (NLL) function.
return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) / norm_factor
역방향 전달을 실행하고 그에 따라 모델의 매개변수를 업데이트하는 train_step
함수를 빌드합니다. 각 항목의 의미는 다음과 같습니다.
jax.value_and_grad
는 정방향 및 역방향 전달 중에 손실 함수와 경사를 평가하는 데 사용됩니다.optax.apply_updates
는 매개변수를 업데이트하는 데 사용됩니다.
Params = Mapping[str, Any]
def get_positions(example: jax.Array, pad_id : int) -> jax.Array:
"""Builds the position vector from the given tokens."""
pad_mask = example != pad_id
positions = jnp.cumsum(pad_mask, axis=-1)
# Subtract one for all positions from the first valid one as they are
# 0-indexed
positions = positions - (positions >= 1)
return positions
@functools.partial(
jax.jit,
static_argnames=['model', 'optimizer'],
donate_argnames=['params', 'opt_state'],
)
def train_step(
model: recurrentgemma.Griffin,
params: Params,
optimizer: optax.GradientTransformation,
opt_state: optax.OptState,
pad_id: int,
example: TrainingInput,
) -> tuple[jax.Array, Params, optax.OptState]:
"""The train step.
Args:
model: The RecurrentGemma (Griffin) model.
params: The model's input parameters.
optimizer: The Optax optimizer to use.
opt_state: The input optimizer's state.
pad_id: The ID of the pad token.
example: The input batch.
Returns:
Training loss, updated parameters, updated optimizer state.
"""
positions = get_positions(example.input_tokens, pad_id)
# Forward and backward passes.
train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(
params,
model=model,
input_tokens=example.input_tokens,
input_mask=example.target_mask,
positions=positions,
)
# Update the parameters.
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
return train_loss, params, opt_state
역방향 전달 없이 validation_step
함수를 빌드합니다.
@functools.partial(jax.jit, static_argnames=['model'])
def validation_step(
model: recurrentgemma.Griffin,
params: Params,
pad_id: int,
example: TrainingInput,
) -> jax.Array:
return forward_and_loss_fn(
params,
model=model,
input_tokens=example.input_tokens,
input_mask=example.target_mask,
positions=get_positions(example.input_tokens, pad_id),
)
학습 루프를 정의합니다.
def train_loop(
model: recurrentgemma.Griffin,
params: Params,
optimizer: optax.GradientTransformation,
train_ds: Iterator[TrainingInput],
validation_ds: Iterator[TrainingInput],
num_steps: int | None = None,
eval_every_n: int = 20,
):
opt_state = jax.jit(optimizer.init)(params)
step_counter = 0
avg_loss=0
# The first round of the validation loss.
n_steps_eval = 0
eval_loss = 0
for val_example in validation_ds.as_numpy_iterator():
eval_loss += validation_step(
model, params, dataset_builder._tokenizer.pad_id, val_example
)
n_steps_eval += 1
print(f"Start, validation loss: {eval_loss/n_steps_eval}")
for train_example in train_ds:
train_loss, params, opt_state = train_step(
model=model,
params=params,
optimizer=optimizer,
opt_state=opt_state,
pad_id=dataset_builder._tokenizer.pad_id,
example=train_example,
)
step_counter += 1
avg_loss += train_loss
if step_counter % eval_every_n == 0:
eval_loss = 0
n_steps_eval = 0
val_iterator = validation_ds.as_numpy_iterator()
for val_example in val_iterator:
eval_loss += validation_step(
model,
params,
dataset_builder._tokenizer.pad_id,
val_example,
)
n_steps_eval +=1
avg_loss /= eval_every_n
eval_loss /= n_steps_eval
print(f"STEP {step_counter} training loss: {avg_loss} - eval loss: {eval_loss}")
avg_loss=0
if num_steps is not None and step_counter > num_steps:
break
return params
여기서 (Optax) 옵티마이저를 선택해야 합니다. 메모리가 작은 기기의 경우 메모리 공간이 훨씬 적으므로 SGD를 사용해야 합니다. 최상의 미세 조정 성능을 얻으려면 Adam-W를 사용해 보세요. 이 예시에서는 이 노트북의 특정 작업에 대한 각 옵티마이저의 최적의 초매개변수가 2b-it
체크포인트에 제공됩니다.
def griffin_weight_decay_mask(params_like: optax.Params) -> Any:
# Don't put weight decay on the RGLRU, the embeddings and any biases
def enable_weight_decay(path: list[Any], _: Any) -> bool:
# Parameters in the LRU and embedder
path = [dict_key.key for dict_key in path]
if 'rg_lru' in path or 'embedder' in path:
return False
# All biases and scales
if path[-1] in ('b', 'scale'):
return False
return True
return jax.tree_util.tree_map_with_path(enable_weight_decay, params_like)
optimizer_choice = "sgd"
if optimizer_choice == "sgd":
optimizer = optax.sgd(learning_rate=1e-3)
num_steps = 300
elif optimizer_choice == "adamw":
optimizer = optax.adamw(
learning_rate=1e-4,
b2=0.96,
eps=1e-8,
weight_decay=0.1,
mask=griffin_weight_decay_mask,
)
num_steps = 100
else:
raise ValueError(f"Unknown optimizer: {optimizer_choice}")
학습 데이터 세트와 검증 데이터 세트를 준비합니다.
# Choose a small sequence length size, so that everything fits in memory.
num_epochs = 1
batch_size = 1
sequence_length = 32
# Make the dataset builder.
tokenizer = GriffinTokenizer(vocab)
dataset_builder= MTNTDatasetBuilder(tokenizer, sequence_length + 1)
# Build the training dataset.
train_ds = dataset_builder.get_train_dataset(
batch_size=batch_size,
num_epochs=num_epochs,
).as_numpy_iterator()
# Build the validation dataset, with a limited number of samples for this demo.
validation_ds = dataset_builder.get_validation_dataset(
batch_size=batch_size,
).take(50)
제한된 수의 단계 (num_steps
)로 RecurrentGemma (Griffin) 모델의 미세 조정을 시작합니다.
trained_params = train_loop(
model=model,
params=params,
optimizer=optimizer,
train_ds=train_ds,
validation_ds=validation_ds,
num_steps=num_steps,
)
Start, validation loss: 7.894117832183838 /usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,33]), ShapedArray(bool[1,33]), ShapedArray(int32[], weak_type=True). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation. warnings.warn("Some donated buffers were not usable:" STEP 20 training loss: 4.592616081237793 - eval loss: 2.847407102584839 STEP 40 training loss: 2.7537424564361572 - eval loss: 2.9258534908294678 STEP 60 training loss: 2.835618257522583 - eval loss: 2.4382340908050537 STEP 80 training loss: 2.6322107315063477 - eval loss: 2.3696839809417725 STEP 100 training loss: 1.8703256845474243 - eval loss: 2.355681896209717 STEP 120 training loss: 2.7280433177948 - eval loss: 2.4059958457946777 STEP 140 training loss: 2.3047447204589844 - eval loss: 2.083082914352417 STEP 160 training loss: 2.3432137966156006 - eval loss: 2.095074415206909 STEP 180 training loss: 2.1081202030181885 - eval loss: 2.006460189819336 STEP 200 training loss: 2.5359647274017334 - eval loss: 1.9667452573776245 STEP 220 training loss: 2.202195644378662 - eval loss: 1.9440618753433228 STEP 240 training loss: 2.756615400314331 - eval loss: 2.1073737144470215 STEP 260 training loss: 2.5128934383392334 - eval loss: 2.117241859436035 STEP 280 training loss: 2.73045015335083 - eval loss: 1.9159646034240723 STEP 300 training loss: 2.0918595790863037 - eval loss: 1.9742532968521118
각 걸음 수가 계산될 때마다 학습 손실과 검증 손실이 모두 감소해야 합니다.
입력이 학습 형식과 일치하도록 하려면 끝에 Translate this into French:\n
접두사와 줄바꿈 문자를 사용해야 합니다. 그러면 모델에 번역을 시작하라는 신호를 보내게 됩니다.
sampler.params = trained_params
output = sampler(
["Translate this into French:\nHello, my name is Morgane.\n"],
total_generation_steps=100,
)
print(output.text[0])
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,16]). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation. warnings.warn("Some donated buffers were not usable:" Mais je m'appelle Morgane.
자세히 알아보기
recurrentgemma.jax.load_parameters
,recurrentgemma.jax.Griffin
,recurrentgemma.jax.Sampler
등 이 튜토리얼에서 사용한 메서드 및 모듈의 docstring이 포함된 Google DeepMindrecurrentgemma
라이브러리에 관해 자세히 알아볼 수 있습니다.- core JAX, Flax, Chex, Optax, Orbax 라이브러리에는 자체 문서 사이트가 있습니다.
sentencepiece
tokenizer/detokenizer 문서는 Google의sentencepiece
GitHub 저장소를 확인하세요.kagglehub
문서는 Kaggle의kagglehub
GitHub 저장소에서README.md
를 확인하세요.- Google Cloud Vertex AI에서 Gemma 모델을 사용하는 방법 알아보기
- Google Cloud TPU (v3-8 이상)를 사용하는 경우 최신
jax[tpu]
패키지 (!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
)로 업데이트하고 런타임을 다시 시작한 다음jax
및jaxlib
버전이 일치하는지 (!pip list | grep jax
) 확인합니다. 이렇게 하면jaxlib
및jax
버전 불일치로 인해 발생할 수 있는RuntimeError
을 방지할 수 있습니다. JAX 설치에 대한 자세한 안내는 JAX 문서를 참조하세요. - RecurrentGemma: Moving Past Transformers 자세한 내용은 Google DeepMind의 효율적인 개방형 언어 모델 보고서를 참조하세요.
- Griffin: Mixing Gated Linear Recurrences with RecurrentGemma에서 사용하는 모델 아키텍처에 대해 자세히 알아보려면 Google DeepMind의 효율적인 언어 모델에 대한 로컬 관심 문서를 참조하세요.