Este tutorial demonstra como ajustar o modelo Instruct 2B da RecurrentGemma para uma tarefa de tradução do inglês para o francês usando a biblioteca recurrentgemma
do Google DeepMind, o JAX (uma biblioteca de computação numérica de alto desempenho), a Flax (a biblioteca de rede neural baseada em JAX), o Chex (uma biblioteca de utilitários para escrever código JAX confiável), a Optax (a biblioteca de processamento e otimização de gradiente baseada em JAX) e o conjunto de dados MTNT (Tradução automática de texto com ruído). Embora o Flax não seja usado diretamente neste notebook, ele foi usado para criar o Gemma.
A biblioteca recurrentgemma
foi escrita com JAX, Flax, Orbax (uma biblioteca baseada em JAX para utilitários de treinamento, como checkpointing) e SentencePiece (uma biblioteca de tokenizer/detokenizer).
Este notebook pode ser executado no Google Colab com a GPU T4. Acesse Editar > Configurações do notebook > Acelerador de hardware e selecione GPU T4.
Configuração
As seções a seguir explicam as etapas para preparar um notebook para usar um modelo RecurrentGemma, incluindo o acesso ao modelo, a obtenção de uma chave de API e a configuração do ambiente de execução do notebook.
Configurar o acesso do Kaggle para o Gemma
Para concluir este tutorial, primeiro siga as instruções de configuração semelhantes à configuração do Gemma, com algumas exceções:
- Acesse o RecurrentGemma (em vez do Gemma) em kaggle.com.
- Selecione um ambiente de execução do Colab com recursos suficientes para executar o modelo RecurrentGemma.
- Gere e configure um nome de usuário e uma chave de API do Kaggle.
Depois de concluir a configuração do RecurrentGemma, passe para a próxima seção, em que você vai definir variáveis de ambiente para seu ambiente do Colab.
Defina as variáveis de ambiente
Defina as variáveis de ambiente para KAGGLE_USERNAME
e KAGGLE_KEY
. Quando receber a mensagem "Conceder acesso?", aceite para fornecer acesso ao segredo.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
Instalar a biblioteca recurrentgemma
No momento, a aceleração de hardware sem custo financeiro do Colab é insufficient para executar este notebook. Se você estiver usando o pagamento por uso ou o Colab Pro, clique em Editar > Configurações do notebook > Selecione GPU A100 > Salvar para ativar a aceleração de hardware.
Em seguida, instale a biblioteca recurrentgemma
do Google DeepMind em github.com/google-deepmind/recurrentgemma
. Se você receber um erro sobre o "resolvedor de dependências do pip", geralmente é possível ignorá-lo.
pip install -q git+https://github.com/google-deepmind/recurrentgemma.git
Importar bibliotecas
Este bloco de notas usa o Flax (para redes neurais), o JAX principal, o SentencePiece (para tokenização), o Chex (uma biblioteca de utilitários para escrever código JAX confiável), o Optax (a biblioteca de processamento e otimização de gradiente) e os conjuntos de dados do TensorFlow.
import pathlib
from typing import Any, Mapping, Iterator
import enum
import functools
import chex
import jax
import jax.numpy as jnp
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
import sentencepiece as spm
from recurrentgemma import jax as recurrentgemma
Carregar o modelo RecurrentGemma
- Carregue o modelo RecurrentGemma com
kagglehub.model_download
, que usa três argumentos:
handle
: o identificador do modelo do Kagglepath
: (string opcional) o caminho localforce_download
: (booleano opcional) força o novo download do modelo.
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub
RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download... 100%|██████████| 3.85G/3.85G [00:50<00:00, 81.5MB/s] Extracting model files...
print('RECURRENTGEMMA_VARIANT:', RECURRENTGEMMA_VARIANT)
RECURRENTGEMMA_VARIANT: 2b-it
- Verifique o local dos pesos do modelo e do tokenizer e defina as variáveis de caminho. O diretório do tokenizer vai estar no diretório principal em que você fez o download do modelo, e os pesos do modelo vão estar em um subdiretório. Exemplo:
- O arquivo
tokenizer.model
vai estar em/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1
. - O ponto de verificação do modelo será
/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it
.
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model
Carregar e preparar o conjunto de dados MTNT e o tokenizer Gemma
Você vai usar o conjunto de dados MTNT (Tradução automática de texto com ruído), disponível em TensorFlow Datasets.
Faça o download da parte do conjunto de dados de inglês para francês do conjunto de dados MTNT e selecione dois exemplos. Cada amostra no conjunto de dados contém duas entradas: src
: a frase original em inglês e dst
: a tradução francesa correspondente.
ds = tfds.load("mtnt/en-fr", split="train")
ds = ds.take(2)
ds = ds.as_numpy_iterator()
for idx, example in enumerate(ds):
print(f'Example {idx}:')
for key, val in example.items():
print(f'{key}: {val}')
print()
Downloading and preparing dataset 35.08 MiB (download: 35.08 MiB, generated: 11.33 MiB, total: 46.41 MiB) to /root/tensorflow_datasets/mtnt/en-fr/1.0.0... Dl Completed...: 0 url [00:00, ? url/s] Dl Size...: 0 MiB [00:00, ? MiB/s] Extraction completed...: 0 file [00:00, ? file/s] Generating splits...: 0%| | 0/3 [00:00<?, ? splits/s] Generating train examples...: 0%| | 0/35692 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incompleteJLH33K/mtnt-train.tfrecord*...: 0%| … Generating test examples...: 0%| | 0/1020 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incompleteJLH33K/mtnt-test.tfrecord*...: 0%| |… Generating valid examples...: 0%| | 0/811 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incompleteJLH33K/mtnt-valid.tfrecord*...: 0%| … Dataset mtnt downloaded and prepared to /root/tensorflow_datasets/mtnt/en-fr/1.0.0. Subsequent calls will reuse this data. Example 0: dst: b'Le groupe de " toutes les \xc3\xa9toiles potentielles de la conf\xc3\xa9rence de l\'Est mais qui ne s\'en sortent pas dans le groupe de l\'Ouest ".' src: b'The group of \xe2\x80\x9ceastern conference potential all stars but not making it in the West\xe2\x80\x9d group.' Example 1: dst: b"Kameron est-elle un peu aigrie de son manque de temps \xc3\xa0 l'\xc3\xa9cran ?" src: b'Is Kameron a Little Salty About Her Lack of Air Time?'
Carregue o tokenizer do Gemma, criado usando sentencepiece.SentencePieceProcessor
:
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
Personalize o SentencePieceProcessor
para a tarefa de tradução do inglês para o francês. Como você vai ajustar a parte em inglês do modelo RecurrentGemma (Griffin), é necessário fazer alguns ajustes, como:
O prefixo de entrada: a adição de um prefixo comum a cada entrada indica a tarefa de tradução. Por exemplo, você pode usar um comando com um prefixo como
Translate this into French: [INPUT_SENTENCE]
.O sufixo de início da tradução: adicionar um sufixo ao final de cada comando instrui o modelo Gemma exatamente quando iniciar o processo de tradução. Uma nova linha deve resolver o problema.
Tokens de modelo de linguagem: os modelos RecurrentGemma (Griffin) esperam um token "início da sequência" no início de cada sequência. Da mesma forma, é necessário adicionar um token "fim da sequência" ao final de cada exemplo de treinamento.
Crie um wrapper personalizado em torno do SentencePieceProcessor
da seguinte maneira:
class GriffinTokenizer:
"""A custom wrapper around a SentencePieceProcessor."""
def __init__(self, spm_processor: spm.SentencePieceProcessor):
self._spm_processor = spm_processor
@property
def pad_id(self) -> int:
"""Fast access to the pad ID."""
return self._spm_processor.pad_id()
def tokenize(
self,
example: str | bytes,
prefix: str = '',
suffix: str = '',
add_eos: bool = True,
) -> jax.Array:
"""
A tokenization function.
Args:
example: Input string to tokenize.
prefix: Prefix to add to the input string.
suffix: Suffix to add to the input string.
add_eos: If True, add an end of sentence token at the end of the output
sequence.
Returns:
Tokens corresponding to the input string.
"""
int_list = [self._spm_processor.bos_id()]
int_list.extend(self._spm_processor.EncodeAsIds(prefix + example + suffix))
if add_eos:
int_list.append(self._spm_processor.eos_id())
return jnp.array(int_list, dtype=jnp.int32)
def tokenize_tf_op(
self,
str_tensor: tf.Tensor,
prefix: str = '',
suffix: str = '',
add_eos: bool = True,
) -> tf.Tensor:
"""A TensforFlow operator for the `tokenize` function."""
encoded = tf.numpy_function(
self.tokenize,
[str_tensor, prefix, suffix, add_eos],
tf.int32)
encoded.set_shape([None])
return encoded
def to_string(self, tokens: jax.Array) -> str:
"""Convert an array of tokens to a string."""
return self._spm_processor.EncodeIds(tokens.tolist())
Para testar, instancie o novo GriffinTokenizer
personalizado e aplique-o a uma pequena amostra do conjunto de dados MTNT:
def tokenize_source(tokenizer, example: tf.Tensor):
return tokenizer.tokenize_tf_op(
example,
prefix='Translate this into French:\n',
suffix='\n',
add_eos=False
)
def tokenize_destination(tokenizer, example: tf.Tensor):
return tokenizer.tokenize_tf_op(example, add_eos=True)
tokenizer = GriffinTokenizer(vocab)
ds = tfds.load("mtnt/en-fr",split="train")
ds = ds.take(2)
ds = ds.map(lambda x: {
'src': tokenize_source(tokenizer, x['src']),
'dst': tokenize_destination(tokenizer, x['dst'])
})
ds = ds.as_numpy_iterator()
for idx, example in enumerate(ds):
print(f'Example {idx}:')
for key, val in example.items():
print(f'{key}: {val}')
print()
Example 0: src: [ 2 49688 736 1280 6987 235292 108 651 2778 576 1080 104745 11982 5736 832 8995 901 780 3547 665 575 573 4589 235369 2778 235265 108] dst: [ 2 2025 29653 581 664 16298 1437 55563 41435 7840 581 683 111452 581 533 235303 9776 4108 2459 679 485 235303 479 6728 579 1806 2499 709 29653 581 533 235303 101323 16054 1] Example 1: src: [ 2 49688 736 1280 6987 235292 108 2437 87150 477 476 11709 230461 8045 3636 40268 576 4252 4897 235336 108] dst: [ 2 213606 477 1455 235290 3510 748 8268 191017 2809 581 2032 69972 581 11495 1305 533 235303 65978 1654 1]
Crie um carregador de dados para todo o conjunto de dados MTNT:
@chex.dataclass(frozen=True)
class TrainingInput:
# Input tokens provided to the model.
input_tokens: jax.Array
# A mask that determines which tokens contribute to the target loss
# calculation.
target_mask: jax.Array
class DatasetSplit(enum.Enum):
TRAIN = 'train'
VALIDATION = 'valid'
class MTNTDatasetBuilder:
"""A data loader for the MTNT dataset."""
N_ITEMS = {DatasetSplit.TRAIN: 35_692, DatasetSplit.VALIDATION: 811}
BUFFER_SIZE_SHUFFLE = 10_000
TRANSLATION_PREFIX = 'Translate this into French:\n'
TRANSLATION_SUFFIX = '\n'
def __init__(self,
tokenizer : GriffinTokenizer,
max_seq_len: int):
"""A constructor.
Args:
tokenizer: The tokenizer to use.
max_seq_len: The size of each sequence in a given batch.
"""
self._tokenizer = tokenizer
self._base_data = {
DatasetSplit.TRAIN: tfds.load("mtnt/en-fr",split="train"),
DatasetSplit.VALIDATION: tfds.load("mtnt/en-fr",split="valid"),
}
self._max_seq_len = max_seq_len
def _tokenize_source(self, example: tf.Tensor):
"""A tokenization function for the source."""
return self._tokenizer.tokenize_tf_op(
example, prefix=self.TRANSLATION_PREFIX, suffix=self.TRANSLATION_SUFFIX,
add_eos=False
)
def _tokenize_destination(self, example: tf.Tensor):
"""A tokenization function for the French translation."""
return self._tokenizer.tokenize_tf_op(example, add_eos=True)
def _pad_up_to_max_len(self,
input_tensor: tf.Tensor,
pad_value: int | bool,
) -> tf.Tensor:
"""Pad the given tensor up to sequence length of a batch."""
seq_len = tf.shape(input_tensor)[0]
to_pad = tf.maximum(self._max_seq_len - seq_len, 0)
return tf.pad(
input_tensor, [[0, to_pad]], mode='CONSTANT', constant_values=pad_value,
)
def _to_training_input(
self,
src_tokens: jax.Array,
dst_tokens: jax.Array,
) -> TrainingInput:
"""Build a training input from a tuple of source and destination tokens."""
# The input sequence fed to the model is simply the concatenation of the
# source and the destination.
tokens = tf.concat([src_tokens, dst_tokens], axis=0)
# You want to prevent the model from updating based on the source (input)
# tokens. To achieve this, add a target mask to each input.
q_mask = tf.zeros_like(src_tokens, dtype=tf.bool)
a_mask = tf.ones_like(dst_tokens, dtype=tf.bool)
mask = tf.concat([q_mask, a_mask], axis=0)
# If the output tokens sequence is smaller than the target sequence size,
# then pad it with pad tokens.
tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)
# You don't want to perform the backward on the pad tokens.
mask = self._pad_up_to_max_len(mask, False)
return TrainingInput(input_tokens=tokens, target_mask=mask)
def get_train_dataset(self, batch_size: int, num_epochs: int):
"""Build the training dataset."""
# Tokenize each sample.
ds = self._base_data[DatasetSplit.TRAIN].map(
lambda x : (self._tokenize_source(x['src']),
self._tokenize_destination(x['dst']))
)
# Convert them to training inputs.
ds = ds.map(lambda x, y: self._to_training_input(x, y))
# Remove the samples which are too long.
ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
# Shuffle the dataset.
ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)
# Repeat if necessary.
ds = ds.repeat(num_epochs)
# Build batches.
ds = ds.batch(batch_size, drop_remainder=True)
return ds
def get_validation_dataset(self, batch_size: int):
"""Build the validation dataset."""
# Same as the training dataset, but no shuffling and no repetition
ds = self._base_data[DatasetSplit.VALIDATION].map(
lambda x : (self._tokenize_source(x['src']),
self._tokenize_destination(x['dst']))
)
ds = ds.map(lambda x, y: self._to_training_input(x, y))
ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
ds = ds.batch(batch_size, drop_remainder=True)
return ds
Teste o MTNTDatasetBuilder
instanciando o GriffinTokenizer
personalizado novamente, aplicando-o no conjunto de dados MTNT e fazendo a amostragem de dois exemplos:
dataset_builder = MTNTDatasetBuilder(tokenizer, max_seq_len=20)
ds = dataset_builder.get_train_dataset(3, 1)
ds = ds.take(2)
ds = ds.as_numpy_iterator()
for idx, example in enumerate(ds):
print(f'Example {idx}:')
for key, val in example.items():
print(f'{key}: {val}')
print()
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'> WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'> WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'> Example 0: input_tokens: [[ 2 49688 736 1280 6987 235292 108 12583 665 235265 108 2 6151 94975 1320 6238 235265 1 0 0] [ 2 49688 736 1280 6987 235292 108 4899 29960 11270 108282 235265 108 2 4899 79025 11270 108282 1 0] [ 2 49688 736 1280 6987 235292 108 26620 235265 108 2 26620 235265 1 0 0 0 0 0 0]] target_mask: [[False False False False False False False False False False False True True True True True True True False False] [False False False False False False False False False False False False False True True True True True True False] [False False False False False False False False False False True True True True False False False False False False]] Example 1: input_tokens: [[ 2 49688 736 1280 6987 235292 108 527 5174 1683 235336 108 2 206790 581 20726 482 2208 1654 1] [ 2 49688 736 1280 6987 235292 108 28484 235256 235336 108 2 120500 13832 1654 1 0 0 0 0] [ 2 49688 736 1280 6987 235292 108 235324 235304 2705 235265 108 2 235324 235304 19963 235265 1 0 0]] target_mask: [[False False False False False False False False False False False False True True True True True True True True] [False False False False False False False False False False False True True True True True False False False False] [False False False False False False False False False False False False True True True True True True False False]]
Configurar o modelo
Antes de começar a ajustar o modelo Gemma, você precisa fazer a configuração.
Carregue o checkpoint do modelo RecurrentGemma (Griffin) com o método recurrentgemma.jax.utils.load_parameters
:
params = recurrentgemma.load_parameters(CKPT_PATH, "single_device")
Para carregar automaticamente a configuração correta do ponto de verificação do modelo RecurrentGemma, use recurrentgemma.GriffinConfig.from_flax_params_or_variables
:
config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(params)
Instanciar o modelo Griffin com recurrentgemma.jax.Griffin
:
model = recurrentgemma.Griffin(config)
Crie um sampler
com recurrentgemma.jax.Sampler
sobre o checkpoint/pesos do modelo RecurrentGemma e o tokenizer para verificar se o modelo pode fazer a tradução:
sampler = recurrentgemma.Sampler(model=model, vocab=vocab, params=params)
Ajustar o modelo
Nesta seção, você:
- Use a classe
gemma.deprecated.transformer.Transformer
para criar a função de transmissão e perda. - Criar os vetores de máscara de posição e de atenção para tokens
- Crie uma função de etapa de treinamento com o Flax.
- Crie a etapa de validação sem a passagem para trás.
- Crie o ciclo de treinamento.
- Ajustar o modelo Gemma.
Defina a passagem direta e a função de perda usando a classe
recurrentgemma.jax.griffin.Griffin
. O Griffin
da RecurrentGemma herda de flax.linen.Module
e oferece dois métodos essenciais:
init
: inicializa os parâmetros do modelo.apply
: executa a função__call__
do modelo usando um determinado conjunto de parâmetros.
Como você está trabalhando com pesos de Gemma pré-treinados, não é necessário usar a função init
.
def forward_and_loss_fn(
params,
*,
model: recurrentgemma.Griffin,
input_tokens: jax.Array, # Shape [B, L]
input_mask: jax.Array, # Shape [B, L]
positions: jax.Array, # Shape [B, L]
) -> jax.Array:
"""Forward pass and loss function.
Args:
params: model's input parameters.
model: Griffin model to call.
input_tokens: input tokens sequence, shape [B, L].
input_mask: tokens to ignore when computing the loss, shape [B, L].
positions: relative position of each token, shape [B, L].
Returns:
Softmax cross-entropy loss for the next-token prediction task.
"""
batch_size = input_tokens.shape[0]
# Forward pass on the input data.
# No attention cache is needed here.
# Exclude the last step as it does not appear in the targets.
logits, _ = model.apply(
{"params": params},
tokens=input_tokens[:, :-1],
segment_pos=positions[:, :-1],
cache=None,
)
# Similarly, the first token cannot be predicteds.
target_tokens = input_tokens[:, 1:]
target_mask = input_mask[:, 1:]
# Convert the target labels into one-hot encoded vectors.
one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])
# Don't update on unwanted tokens.
one_hot = one_hot * target_mask.astype(one_hot.dtype)[...,None]
# Normalization factor.
norm_factor = batch_size * (jnp.sum(target_mask) + 1e-8)
# Return the negative log-likelihood loss (NLL) function.
return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) / norm_factor
Crie a função train_step
que executa a passagem reversa e atualiza os parâmetros do modelo de acordo, em que:
jax.value_and_grad
é para avaliar a função de perda e os gradientes durante as transmissões para frente e para trás.optax.apply_updates
é para atualizar os parâmetros.
Params = Mapping[str, Any]
def get_positions(example: jax.Array, pad_id : int) -> jax.Array:
"""Builds the position vector from the given tokens."""
pad_mask = example != pad_id
positions = jnp.cumsum(pad_mask, axis=-1)
# Subtract one for all positions from the first valid one as they are
# 0-indexed
positions = positions - (positions >= 1)
return positions
@functools.partial(
jax.jit,
static_argnames=['model', 'optimizer'],
donate_argnames=['params', 'opt_state'],
)
def train_step(
model: recurrentgemma.Griffin,
params: Params,
optimizer: optax.GradientTransformation,
opt_state: optax.OptState,
pad_id: int,
example: TrainingInput,
) -> tuple[jax.Array, Params, optax.OptState]:
"""The train step.
Args:
model: The RecurrentGemma (Griffin) model.
params: The model's input parameters.
optimizer: The Optax optimizer to use.
opt_state: The input optimizer's state.
pad_id: The ID of the pad token.
example: The input batch.
Returns:
Training loss, updated parameters, updated optimizer state.
"""
positions = get_positions(example.input_tokens, pad_id)
# Forward and backward passes.
train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(
params,
model=model,
input_tokens=example.input_tokens,
input_mask=example.target_mask,
positions=positions,
)
# Update the parameters.
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
return train_loss, params, opt_state
Crie a função validation_step
sem a passagem para trás:
@functools.partial(jax.jit, static_argnames=['model'])
def validation_step(
model: recurrentgemma.Griffin,
params: Params,
pad_id: int,
example: TrainingInput,
) -> jax.Array:
return forward_and_loss_fn(
params,
model=model,
input_tokens=example.input_tokens,
input_mask=example.target_mask,
positions=get_positions(example.input_tokens, pad_id),
)
Defina o ciclo de treinamento:
def train_loop(
model: recurrentgemma.Griffin,
params: Params,
optimizer: optax.GradientTransformation,
train_ds: Iterator[TrainingInput],
validation_ds: Iterator[TrainingInput],
num_steps: int | None = None,
eval_every_n: int = 20,
):
opt_state = jax.jit(optimizer.init)(params)
step_counter = 0
avg_loss=0
# The first round of the validation loss.
n_steps_eval = 0
eval_loss = 0
for val_example in validation_ds.as_numpy_iterator():
eval_loss += validation_step(
model, params, dataset_builder._tokenizer.pad_id, val_example
)
n_steps_eval += 1
print(f"Start, validation loss: {eval_loss/n_steps_eval}")
for train_example in train_ds:
train_loss, params, opt_state = train_step(
model=model,
params=params,
optimizer=optimizer,
opt_state=opt_state,
pad_id=dataset_builder._tokenizer.pad_id,
example=train_example,
)
step_counter += 1
avg_loss += train_loss
if step_counter % eval_every_n == 0:
eval_loss = 0
n_steps_eval = 0
val_iterator = validation_ds.as_numpy_iterator()
for val_example in val_iterator:
eval_loss += validation_step(
model,
params,
dataset_builder._tokenizer.pad_id,
val_example,
)
n_steps_eval +=1
avg_loss /= eval_every_n
eval_loss /= n_steps_eval
print(f"STEP {step_counter} training loss: {avg_loss} - eval loss: {eval_loss}")
avg_loss=0
if num_steps is not None and step_counter > num_steps:
break
return params
Aqui, você precisa escolher um otimizador (Optax). Para dispositivos com memória menor, use o SGD, que tem um consumo de memória muito menor. Para ter a melhor performance de ajuste fino, use Adam-W. Os hiperparâmetros ideais para cada otimizador para a tarefa específica neste notebook são fornecidos neste exemplo para o ponto de verificação 2b-it
.
def griffin_weight_decay_mask(params_like: optax.Params) -> Any:
# Don't put weight decay on the RGLRU, the embeddings and any biases
def enable_weight_decay(path: list[Any], _: Any) -> bool:
# Parameters in the LRU and embedder
path = [dict_key.key for dict_key in path]
if 'rg_lru' in path or 'embedder' in path:
return False
# All biases and scales
if path[-1] in ('b', 'scale'):
return False
return True
return jax.tree_util.tree_map_with_path(enable_weight_decay, params_like)
optimizer_choice = "sgd"
if optimizer_choice == "sgd":
optimizer = optax.sgd(learning_rate=1e-3)
num_steps = 300
elif optimizer_choice == "adamw":
optimizer = optax.adamw(
learning_rate=1e-4,
b2=0.96,
eps=1e-8,
weight_decay=0.1,
mask=griffin_weight_decay_mask,
)
num_steps = 100
else:
raise ValueError(f"Unknown optimizer: {optimizer_choice}")
Prepare os conjuntos de dados de treinamento e validação:
# Choose a small sequence length size, so that everything fits in memory.
num_epochs = 1
batch_size = 1
sequence_length = 32
# Make the dataset builder.
tokenizer = GriffinTokenizer(vocab)
dataset_builder= MTNTDatasetBuilder(tokenizer, sequence_length + 1)
# Build the training dataset.
train_ds = dataset_builder.get_train_dataset(
batch_size=batch_size,
num_epochs=num_epochs,
).as_numpy_iterator()
# Build the validation dataset, with a limited number of samples for this demo.
validation_ds = dataset_builder.get_validation_dataset(
batch_size=batch_size,
).take(50)
Comece a ajustar o modelo RecurrentGemma (Griffin) em um número limitado de etapas (num_steps
):
trained_params = train_loop(
model=model,
params=params,
optimizer=optimizer,
train_ds=train_ds,
validation_ds=validation_ds,
num_steps=num_steps,
)
Start, validation loss: 7.894117832183838 /usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,33]), ShapedArray(bool[1,33]), ShapedArray(int32[], weak_type=True). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation. warnings.warn("Some donated buffers were not usable:" STEP 20 training loss: 4.592616081237793 - eval loss: 2.847407102584839 STEP 40 training loss: 2.7537424564361572 - eval loss: 2.9258534908294678 STEP 60 training loss: 2.835618257522583 - eval loss: 2.4382340908050537 STEP 80 training loss: 2.6322107315063477 - eval loss: 2.3696839809417725 STEP 100 training loss: 1.8703256845474243 - eval loss: 2.355681896209717 STEP 120 training loss: 2.7280433177948 - eval loss: 2.4059958457946777 STEP 140 training loss: 2.3047447204589844 - eval loss: 2.083082914352417 STEP 160 training loss: 2.3432137966156006 - eval loss: 2.095074415206909 STEP 180 training loss: 2.1081202030181885 - eval loss: 2.006460189819336 STEP 200 training loss: 2.5359647274017334 - eval loss: 1.9667452573776245 STEP 220 training loss: 2.202195644378662 - eval loss: 1.9440618753433228 STEP 240 training loss: 2.756615400314331 - eval loss: 2.1073737144470215 STEP 260 training loss: 2.5128934383392334 - eval loss: 2.117241859436035 STEP 280 training loss: 2.73045015335083 - eval loss: 1.9159646034240723 STEP 300 training loss: 2.0918595790863037 - eval loss: 1.9742532968521118
A perda de treinamento e a perda de validação devem ter diminuído com cada contagem de etapas.
Para garantir que a entrada corresponda ao formato de treinamento, use o prefixo Translate this into French:\n
e um caractere de nova linha no final. Isso sinaliza ao modelo para iniciar a tradução.
sampler.params = trained_params
output = sampler(
["Translate this into French:\nHello, my name is Morgane.\n"],
total_generation_steps=100,
)
print(output.text[0])
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,16]). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation. warnings.warn("Some donated buffers were not usable:" Mais je m'appelle Morgane.
Saiba mais
- Saiba mais sobre a biblioteca
recurrentgemma
do Google DeepMind no GitHub (link em inglês), que contém docstrings de métodos e módulos usados neste tutorial, comorecurrentgemma.jax.load_parameters
,recurrentgemma.jax.Griffin
erecurrentgemma.jax.Sampler
. - As seguintes bibliotecas têm os próprios sites de documentação: core JAX, Flax, Chex, Optax e Orbax.
- Para a documentação do
sentencepiece
tokenizer/detokenizer, consulte o repositório do GitHubsentencepiece
do Google. - Para conferir a documentação de
kagglehub
, consulteREADME.md
no repositório do GitHubkagglehub
do Kaggle. - Saiba como usar modelos Gemma com a Vertex AI do Google Cloud.
- Se você estiver usando TPUs do Google Cloud (v3-8 e versões mais recentes), atualize também para o pacote
jax[tpu]
mais recente (!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
), reinicie o ambiente de execução e verifique se as versõesjax
ejaxlib
correspondem (!pip list | grep jax
). Isso pode impedir oRuntimeError
que pode surgir devido à incompatibilidade de versõesjaxlib
ejax
. Para mais instruções de instalação do JAX, consulte os documentos do JAX. - Confira o artigo RecurrentGemma: Moving Past Transformers for Efficient Open Language Models (em inglês) do Google DeepMind.
- Leia o artigo Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models (em inglês) do Google DeepMind para saber mais sobre a arquitetura de modelos usada pelo RecurrentGemma.