Ver em ai.google.dev | Executar no Google Colab | Abrir na Vertex AI | Consulte o código-fonte no GitHub |
Visão geral
A Gemma é uma família de modelos de linguagem grandes, leves e modernos, com base na pesquisa e tecnologia do Google DeepMind Gemini. Este tutorial demonstra como ajustar o modelo Gemma 2B Instruct para uma tarefa de tradução do inglês para francês usando a biblioteca gemma
do Google DeepMind, JAX (uma biblioteca de computação numérica de alto desempenho), Flax (a biblioteca de rede neural baseada em JAX), Chex (uma biblioteca de utilitários para escrever código JAX confiável e processamento JAX), Optax-com base no conjunto de dados JAX1 e processamento de JAX (o conjunto de dados JAX JAX e NoNT1). Embora o Flax não seja usado diretamente neste notebook, ele foi usado para criar o Gemma.
A biblioteca gemma
foi escrita com JAX, Flax, Orbax (uma biblioteca baseada em JAX para utilitários de treinamento, como checkpointing) e SentencePiece (uma biblioteca de tokenizer/detokenizer).
Configuração
1. Configurar o acesso do Kaggle para o Gemma
Para concluir este tutorial, primeiro você precisa seguir as instruções de configuração em Gemma setup, que mostram como fazer o seguinte:
- Acesse o Gemma em kaggle.com.
- Selecione um ambiente de execução do Colab com recursos suficientes para executar o modelo Gemma.
- Gere e configure um nome de usuário e uma chave de API do Kaggle.
Depois de concluir a configuração do Gemma, passe para a próxima seção, em que você definirá variáveis para seu ambiente do Colab.
2. Defina as variáveis de ambiente
Defina as variáveis de ambiente para KAGGLE_USERNAME
e KAGGLE_KEY
. Quando a mensagem "Permitir acesso?" for exibida, concorde em fornecer acesso secreto.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
3. Instalar a biblioteca gemma
No momento, a aceleração de hardware sem custo financeiro do Colab é insuficiente para executar este notebook. Se você estiver usando o Colab Pay As Go ou o Colab Pro, clique em Editar > Configurações do notebook > Selecione A100 GPU > Salvar para ativar a aceleração de hardware.
Em seguida, você precisa instalar a biblioteca gemma
do Google DeepMind de github.com/google-deepmind/gemma
. Geralmente, se você receber um erro sobre o "resolvedor de dependências do pip", ele poderá ser ignorado.
pip install -q git+https://github.com/google-deepmind/gemma.git
4. Importar bibliotecas
Este notebook usa o Flax (para redes neurais), o JAX principal, o SentencePiece (para tokenização), o Chex (uma biblioteca de utilitários para escrever códigos JAX confiáveis) e os conjuntos de dados do TensorFlow.
import os
import enum
import re
import string
import chex
import jax
import jax.numpy as jnp
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm
Carregar o modelo Gemma
Carregue o modelo Gemma com kagglehub.model_download
, que usa três argumentos:
handle
: o identificador de modelo do Kagglepath
: (string opcional) o caminho localforce_download
: (booleano opcional) força o novo download do modelo
GEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub
GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/gemma/flax/2b-it/2/download... 100%|██████████| 3.67G/3.67G [00:26<00:00, 147MB/s] Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2
Verifique o local dos pesos do modelo e do tokenizador e, em seguida, defina as variáveis de caminho. O diretório do tokenizador estará no diretório principal onde você fez o download do modelo, enquanto os pesos do modelo estarão em um subdiretório. Exemplo:
- O arquivo
tokenizer.model
estará em/LOCAL/PATH/TO/gemma/flax/2b-it/2
). - O checkpoint do modelo estará em
/LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it
.
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/2b-it TOKENIZER_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/tokenizer.model
Carregar e preparar o conjunto de dados MTNT e o tokenizador Gemma
Você vai usar o conjunto de dados MTNT (tradução automática de texto com ruído), disponível nos conjuntos de dados do TensorFlow.
Faça o download da parte do conjunto de dados do inglês para o francês do MTNT e de dois exemplos. Cada amostra no conjunto de dados contém duas entradas: src
: a frase original em inglês e dst
: a tradução em francês correspondente.
ds = tfds.load("mtnt/en-fr", split="train")
ds = ds.take(2)
ds = ds.as_numpy_iterator()
for idx, example in enumerate(ds):
print(f'Example {idx}:')
for key, val in example.items():
print(f'{key}: {val}')
print()
Downloading and preparing dataset 35.08 MiB (download: 35.08 MiB, generated: 11.33 MiB, total: 46.41 MiB) to /root/tensorflow_datasets/mtnt/en-fr/1.0.0... Dl Completed...: 0 url [00:00, ? url/s] Dl Size...: 0 MiB [00:00, ? MiB/s] Extraction completed...: 0 file [00:00, ? file/s] Generating splits...: 0%| | 0/3 [00:00<?, ? splits/s] Generating train examples...: 0%| | 0/35692 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incomplete6YJMND/mtnt-train.tfrecord*...: 0%| … Generating test examples...: 0%| | 0/1020 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incomplete6YJMND/mtnt-test.tfrecord*...: 0%| |… Generating valid examples...: 0%| | 0/811 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incomplete6YJMND/mtnt-valid.tfrecord*...: 0%| … Dataset mtnt downloaded and prepared to /root/tensorflow_datasets/mtnt/en-fr/1.0.0. Subsequent calls will reuse this data. Example 0: dst: b'Le groupe de " toutes les \xc3\xa9toiles potentielles de la conf\xc3\xa9rence de l\'Est mais qui ne s\'en sortent pas dans le groupe de l\'Ouest ".' src: b'The group of \xe2\x80\x9ceastern conference potential all stars but not making it in the West\xe2\x80\x9d group.' Example 1: dst: b"Kameron est-elle un peu aigrie de son manque de temps \xc3\xa0 l'\xc3\xa9cran ?" src: b'Is Kameron a Little Salty About Her Lack of Air Time?'
Carregue o tokenizador Gemma, construído usando sentencepiece.SentencePieceProcessor
:
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
Personalize o SentencePieceProcessor
da tarefa de tradução do inglês para o francês. Como você ajustará a parte inglesa do modelo Gemma, será necessário fazer alguns ajustes, como:
Prefixo de entrada: adicionar um prefixo comum a cada entrada sinaliza a tarefa de tradução. Por exemplo, use um prompt com um prefixo como
Translate this into French: [INPUT_SENTENCE]
.O sufixo inicial da tradução: adicionar um sufixo ao final de cada comando instrui o modelo Gemma exatamente a começar o processo de tradução. Uma nova linha deve fazer o trabalho.
Tokens do modelo de linguagem: os modelos Gemma esperam um token de "início da sequência" no início de cada sequência. Portanto, adicionar um token de "fim de sequência" no final de cada exemplo de treinamento deve ser suficiente.
Crie um wrapper personalizado em torno do
SentencePieceProcessor
da seguinte maneira:
class GemmaTokenizer:
def __init__(self,
spm_processor: spm.SentencePieceProcessor):
self._spm_processor = spm_processor
@property
def pad_id(self) -> int:
"""Fast access to the pad ID."""
return self._spm_processor.pad_id()
def tokenize(self,
example: str | bytes,
prefix: str = '',
suffix: str = '',
add_eos: bool = True) -> jax.Array:
"""
The tokenization function.
Args:
example: Input string to tokenize.
prefix: Prefix to add to the input string.
suffix: Suffix to add to the input string.
add_eos: If True, add an "end of sentence" token at the end of the output
sequence.
Returns:
Tokens corresponding to the input string.
"""
int_list = [self._spm_processor.bos_id()]
int_list.extend(self._spm_processor.EncodeAsIds(prefix + example + suffix))
if add_eos:
int_list.append(self._spm_processor.eos_id())
return jnp.array(int_list, dtype=jnp.int32)
def tokenize_tf_op(self,
str_tensor: tf.Tensor,
prefix: str = '',
suffix: str = '',
add_eos: bool = True) -> tf.Tensor:
"""A TensorFlow operator for the tokenize function."""
encoded = tf.numpy_function(
self.tokenize,
[str_tensor, prefix, suffix, add_eos],
tf.int32)
encoded.set_shape([None])
return encoded
def to_string(self, tokens: jax.Array) -> str:
"""Convert an array of tokens to a string."""
return self._spm_processor.EncodeIds(tokens.tolist())
Faça um teste instanciando seu novo GemmaTokenizer
personalizado e aplicando-o em uma pequena amostra do conjunto de dados MTNT:
tokenizer = GemmaTokenizer(vocab)
def tokenize_source(tokenizer, example: tf.Tensor):
return tokenizer.tokenize_tf_op(example,
prefix='Translate this into French:\n',
suffix='\n',
add_eos=False)
def tokenize_destination(tokenizer, example: tf.Tensor):
return tokenizer.tokenize_tf_op(example,
add_eos=True)
ds = tfds.load("mtnt/en-fr",split="train")
ds = ds.take(2)
ds = ds.map(lambda x: {'src': tokenize_source(tokenizer, x['src']),
'dst': tokenize_destination(tokenizer, x['dst'])})
ds = ds.as_numpy_iterator()
for idx, example in enumerate(ds):
print(f'Example {idx}:')
for key, val in example.items():
print(f'{key}: {val}')
print()
Example 0: src: [ 2 49688 736 1280 6987 235292 108 651 2778 576 1080 104745 11982 5736 832 8995 901 780 3547 665 575 573 4589 235369 2778 235265 108] dst: [ 2 2025 29653 581 664 16298 1437 55563 41435 7840 581 683 111452 581 533 235303 9776 4108 2459 679 485 235303 479 6728 579 1806 2499 709 29653 581 533 235303 101323 16054 1] Example 1: src: [ 2 49688 736 1280 6987 235292 108 2437 87150 477 476 11709 230461 8045 3636 40268 576 4252 4897 235336 108] dst: [ 2 213606 477 1455 235290 3510 748 8268 191017 2809 581 2032 69972 581 11495 1305 533 235303 65978 1654 1]
Crie um carregador para todo o conjunto de dados MTNT:
@chex.dataclass(frozen=True)
class TrainingInput:
# Input tokens provided to the model.
input_tokens: jax.Array
# A mask that determines which tokens contribute to the target loss
# calculation.
target_mask: jax.Array
class DatasetSplit(enum.Enum):
TRAIN = 'train'
VALIDATION = 'valid'
class MTNTDatasetBuilder:
"""The dataset builder for the MTNT dataset."""
N_ITEMS = {DatasetSplit.TRAIN: 35_692,
DatasetSplit.VALIDATION: 811}
BUFFER_SIZE_SHUFFLE = 10_000
TRANSLATION_PREFIX = 'Translate this into French:\n'
TRANSLATION_SUFFIX = '\n'
def __init__(self,
tokenizer : GemmaTokenizer,
max_seq_len: int):
"""Constructor.
Args:
tokenizer: Gemma tokenizer to use.
max_seq_len: size of each sequence in a given batch.
"""
self._tokenizer = tokenizer
self._base_data = {
DatasetSplit.TRAIN: tfds.load("mtnt/en-fr",split="train"),
DatasetSplit.VALIDATION: tfds.load("mtnt/en-fr",split="valid"),
}
self._max_seq_len = max_seq_len
def _tokenize_source(self, example: tf.Tensor):
"""Tokenization function for the source."""
return self._tokenizer.tokenize_tf_op(example,
prefix=self.TRANSLATION_PREFIX,
suffix=self.TRANSLATION_SUFFIX,
add_eos=False)
def _tokenize_destination(self, example: tf.Tensor):
"""Tokenization function for the French translation."""
return self._tokenizer.tokenize_tf_op(example,
add_eos=True)
def _pad_up_to_max_len(self,
input_tensor: tf.Tensor,
pad_value: int | bool,
) -> tf.Tensor:
"""Pad the given tensor up to sequence length of a batch."""
seq_len = tf.shape(input_tensor)[0]
to_pad = tf.maximum(self._max_seq_len - seq_len, 0)
return tf.pad(input_tensor,
[[0, to_pad]],
mode='CONSTANT',
constant_values=pad_value,
)
def _to_training_input(self,
src_tokens: jax.Array,
dst_tokens: jax.Array,
) -> TrainingInput:
"""Build a training input from a tuple of source and destination tokens."""
# The input sequence fed to the model is simply the concatenation of the
# source and the destination.
tokens = tf.concat([src_tokens, dst_tokens], axis=0)
# To prevent the model from updating based on the source (input)
# tokens, add a target mask to each input.
q_mask = tf.zeros_like(src_tokens, dtype=tf.bool)
a_mask = tf.ones_like(dst_tokens, dtype=tf.bool)
mask = tf.concat([q_mask, a_mask], axis=0)
# If the output tokens sequence is smaller than the target sequence size,
# then pad it with pad tokens.
tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)
# Don't want to perform the backward pass on the pad tokens.
mask = self._pad_up_to_max_len(mask, False)
return TrainingInput(input_tokens=tokens, target_mask=mask)
def get_train_dataset(self, batch_size: int, num_epochs: int):
"""Build the training dataset."""
# Tokenize each sample.
ds = self._base_data[DatasetSplit.TRAIN].map(lambda x : (self._tokenize_source(x['src']),
self._tokenize_destination(x['dst'])))
# Convert the samples to training inputs.
ds = ds.map(lambda x, y: self._to_training_input(x, y))
# Remove the samples that are too long.
ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
# Shuffle the dataset.
ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)
# Repeat if necessary.
ds = ds.repeat(num_epochs)
# Build batches.
ds = ds.batch(batch_size, drop_remainder=True)
return ds
def get_validation_dataset(self, batch_size: int):
"""Build the validation dataset."""
# Same steps as in `get_train_dataset`, but without shuffling and no repetition.
ds = self._base_data[DatasetSplit.VALIDATION].map(lambda x : (self._tokenize_source(x['src']),
self._tokenize_destination(x['dst'])))
ds = ds.map(lambda x, y: self._to_training_input(x, y))
ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
ds = ds.batch(batch_size, drop_remainder=True)
return ds
Teste o MTNTDatasetBuilder
instanciando o GemmaTokenizer
personalizado novamente, aplicando-o ao conjunto de dados MTNT e fazendo a amostragem de dois exemplos:
tokenizer = GemmaTokenizer(vocab)
dataset_builder = MTNTDatasetBuilder(tokenizer, max_seq_len=20)
ds = dataset_builder.get_train_dataset(3, 1)
ds = ds.take(2)
ds = ds.as_numpy_iterator()
for idx, example in enumerate(ds):
print(f'Example {idx}:')
for key, val in example.items():
print(f'{key}: {val}')
print()
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'> WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'> WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'> Example 0: input_tokens: [[ 2 49688 736 1280 6987 235292 108 10924 665 12302 235341 108 2 4397 63011 1437 38696 1241 1 0] [ 2 49688 736 1280 6987 235292 108 13835 1517 235265 108 2 69875 540 19713 235265 1 0 0 0] [ 2 49688 736 1280 6987 235292 108 6956 1586 235297 235265 108 2 78368 1586 235297 235265 1 0 0]] target_mask: [[False False False False False False False False False False False False True True True True True True True False] [False False False False False False False False False False False True True True True True True False False False] [False False False False False False False False False False False False True True True True True True False False]] Example 1: input_tokens: [[ 2 49688 736 1280 6987 235292 108 18874 235341 108 2 115905 6425 1241 1 0 0 0 0 0] [ 2 49688 736 1280 6987 235292 108 7574 3356 235341 108 2 7997 20707 1241 1 0 0 0 0] [ 2 49688 736 1280 6987 235292 108 8703 665 235265 108 2 235338 235303 90006 20133 235265 1 0 0]] target_mask: [[False False False False False False False False False False True True True True True False False False False False] [False False False False False False False False False False False True True True True True False False False False] [False False False False False False False False False False False True True True True True True True False False]]
Configurar o modelo
Antes de começar a ajustar o modelo Gemma, é preciso configurá-lo.
Primeiro, carregue e formate o checkpoint do modelo Gemma com o método gemma.params.load_and_format_params
:
params = params_lib.load_and_format_params(CKPT_PATH)
Para carregar automaticamente a configuração correta do checkpoint do modelo Gemma, use gemma.transformer.TransformerConfig
. O argumento cache_size
é o número de etapas de tempo no cache Transformer
do Gemma. Em seguida, instancie o modelo Gemma como model_2b
com gemma.transformer.Transformer
(herdado de flax.linen.Module
).
config_2b = transformer_lib.TransformerConfig.from_params(
params,
cache_size=30
)
model_2b = transformer_lib.Transformer(config=config_2b)
Ajustar o modelo
Nesta seção, você:
- Use a classe
gemma.transformer.Transformer
para criar a função de avanço e perda. - Criar os vetores de posição e máscara de atenção para tokens
- Crie uma função de etapa de treinamento com o Flax.
- Crie a etapa de validação sem a passagem para trás.
- Criar o loop de treinamento.
- Ajustar o modelo do Gemma.
Defina o passe para frente e a função de perda usando a classe gemma.transformer.Transformer
. A Transformer
da Gemma é herdada de flax.linen.Module
e oferece dois métodos essenciais:
init
: inicializa os parâmetros do modelo.apply
: executa a função__call__
do modelo usando um determinado conjunto de parâmetros.Como você está trabalhando com pesos Gemma pré-treinados, não é necessário usar a função
init
.
def forward_and_loss_fn(params,
*,
model: transformer_lib.Transformer,
input_tokens: jax.Array, # Shape [B, L]
input_mask: jax.Array, # Shape [B, L]
positions: jax.Array, # Shape [B, L]
attention_mask: jax.Array, # [B, L, L]
) -> jax.Array:
"""The forward pass and the loss function.
Args:
params: Model's input parameters.
model: The Gemma transformer model to call.
input_tokens: Input tokens sequence, shape [B, L].
input_mask: Tokens to ignore when computing the loss, shape [B, L].
positions: Relative position of each token, shape [B, L].
attention_mask: Input attention mask, shape [B, L].
Returns:
The softmax cross-entropy loss for the next-token prediction task.
"""
# The forward pass on the input data.
# No attention cache is needed here.
logits, _ = model.apply(
params,
input_tokens,
positions,
None, # Attention cache is None.
attention_mask,
)
# Exclude the last step as it does not appear in the targets.
logits = logits[0, :-1]
# Similarly, the first token cannot be predicted.
target_tokens = input_tokens[0, 1:]
target_mask = input_mask[0, 1:]
# Convert the target labels to one-hot encoded vectors.
one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])
# Don't update on unwanted tokens.
one_hot = one_hot * target_mask.astype(one_hot.dtype)[...,None]
# Define the normalization factor.
norm_factor = 1 / (jnp.sum(target_mask) + 1e-8)
# Return the negative log likelihood (NLL) loss.
return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) * norm_factor
A classe gemma.transformer.Transformer
requer um vetor attention_mask
e um positions
em cada entrada. É possível gerá-los criando uma função personalizada que usa Transformer.build_positions_from_mask
e Transformer.make_causal_attn_mask
:
def get_attention_mask_and_positions(example: jax.Array,
pad_id : int,
)-> tuple[jax.Array, jax.Array]:
"""Builds the position and attention mask vectors from the given tokens."""
pad_mask = example != pad_id
current_token_position = transformer_lib.build_positions_from_mask(pad_mask)
attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)
return current_token_position, attention_mask
Crie a função train_step
que executa a passagem para trás e atualiza os parâmetros do modelo adequadamente, em que:
jax.value_and_grad
serve para avaliar a função de perda e os gradientes durante os passes para frente e para trás.optax.apply_updates
serve para atualizar os parâmetros.
def train_step(model: transformer_lib.Transformer,
params,
optimizer: optax.GradientTransformation,
opt_state: optax.OptState,
pad_id: int,
example: TrainingInput):
"""Train step.
Args:
model: The Gemma transformer model.
params: The model's input parameters.
optimizer: The Optax optimizer to use.
opt_state: The input optimizer's state.
pad_id: ID of the pad token.
example: Input batch.
Returns:
The training loss, the updated parameters, and the updated optimizer state.
"""
# Build the position and attention mask vectors.
positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)
# The forward and backward passes.
train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(params,
model=model,
input_tokens=example.input_tokens,
input_mask=example.target_mask,
positions=positions,
attention_mask=attention_mask)
# Update the parameters.
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
return train_loss, params, opt_state
Crie a função validation_step
sem o passe para trás:
def validation_step(model: transformer_lib.Transformer,
params,
pad_id: int,
example: TrainingInput,
):
positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)
val_loss = forward_and_loss_fn(params,
model=model,
input_tokens=example.input_tokens,
input_mask=example.target_mask,
positions=positions,
attention_mask=attention_mask)
return val_loss
Defina o loop de treinamento usando optax.sgd
para o otimizador GDE:
@chex.dataclass(frozen=True)
class TrainingConfig:
learning_rate: float
num_epochs: int
eval_every_n: int
batch_size: int
max_steps: int | None = None
def train_loop(
model: transformer_lib.Transformer,
params,
dataset_builder: MTNTDatasetBuilder,
training_cfg: TrainingConfig):
# Apply `jax.jit` on the training step, making the whole loop much more efficient.
compiled_train_step = jax.jit(train_step, static_argnames=['model', 'optimizer'])
# Apply `jax.jit` on the validation step.
compiled_validation_step = jax.jit(validation_step, static_argnames=['model'])
# To save memory, use the SGD optimizer instead of the usual Adam optimizer.
# Note that for this specific example, SGD is more than enough.
optimizer = optax.sgd(training_cfg.learning_rate)
opt_state = optimizer.init(params)
# Build the training dataset.
train_ds = dataset_builder.get_train_dataset(batch_size=training_cfg.batch_size,
num_epochs=training_cfg.num_epochs)
train_ds = train_ds.as_numpy_iterator()
# Build the validation dataset, with a limited number of samples for this demo.
validation_ds = dataset_builder.get_validation_dataset(batch_size=training_cfg.batch_size)
validation_ds = validation_ds.take(50)
n_steps = 0
avg_loss=0
# A first round of the validation loss.
n_steps_eval = 0
eval_loss = 0
val_iterator = validation_ds.as_numpy_iterator()
for val_example in val_iterator:
eval_loss += compiled_validation_step(model,
params,
dataset_builder._tokenizer.pad_id,
val_example)
n_steps_eval += 1
print(f"Start, validation loss: {eval_loss/n_steps_eval}")
for train_example in train_ds:
train_loss, params, opt_state = compiled_train_step(model=model,
params=params,
optimizer=optimizer,
opt_state=opt_state,
pad_id=dataset_builder._tokenizer.pad_id,
example=train_example)
n_steps += 1
avg_loss += train_loss
if n_steps % training_cfg.eval_every_n == 0:
eval_loss = 0
n_steps_eval = 0
val_iterator = validation_ds.as_numpy_iterator()
for val_example in val_iterator:
eval_loss += compiled_validation_step(model,
params,
dataset_builder._tokenizer.pad_id,
val_example)
n_steps_eval +=1
avg_loss /= training_cfg.eval_every_n
eval_loss /= n_steps_eval
print(f"STEP {n_steps} training loss: {avg_loss} - eval loss: {eval_loss}")
avg_loss=0
if training_cfg.max_steps is not None and n_steps > training_cfg.max_steps:
break
return params
Comece a ajustar o modelo Gemma em um número limitado de etapas (SEQ_SIZE
) para garantir que isso se encaixe na memória:
SEQ_SIZE = 25
tokenizer = GemmaTokenizer(vocab)
dataset_builder= MTNTDatasetBuilder(tokenizer, SEQ_SIZE)
training_cfg = TrainingConfig(learning_rate=1e-4,
num_epochs=1,
eval_every_n=20,
batch_size=1,
max_steps=100)
params = train_loop(model=model_2b,
params={'params': params['transformer']},
dataset_builder=dataset_builder,
training_cfg=training_cfg)
Start, validation loss: 10.647212982177734 STEP 20 training loss: 3.3015992641448975 - eval loss: 2.686880111694336 STEP 40 training loss: 5.375057220458984 - eval loss: 2.6751961708068848 STEP 60 training loss: 2.6599338054656982 - eval loss: 2.663877010345459 STEP 80 training loss: 4.822389125823975 - eval loss: 2.3333375453948975 STEP 100 training loss: 2.0131142139434814 - eval loss: 2.360811948776245
Tanto a perda do treinamento quanto a de validação devem ter diminuído a cada contagem de passos.
Crie um sampler
com o gemma.sampler.Sampler
. Ele usa o checkpoint do modelo Gemma e o tokenizador.
sampler = sampler_lib.Sampler(
transformer=model_2b,
vocab=vocab,
params=params['params'],
)
Use sampler
para verificar se o modelo pode fazer a translação. O argumento total_generation_steps
em gemma.sampler.Sampler
é o número de etapas realizadas ao gerar uma resposta. Para garantir que a entrada corresponda ao formato de treinamento, use o prefixo Translate this into French:\n
com um caractere de nova linha no final. Isso sinaliza ao modelo para começar a tradução.
sampler(
["Translate this into French:\nHello, my name is Morgane.\n"],
total_generation_steps=100,
).text
["C'est Bonjour, mon nom est Morgane.C'est Bonjour, mon nom est Morgane."]
Saiba mais
- Saiba mais sobre a biblioteca
gemma
do Google DeepMind no GitHub, que contém docstrings de módulos usados neste tutorial, comogemma.params
,gemma.transformer
egemma.sampler
. - As seguintes bibliotecas têm os próprios sites de documentação: core JAX, Flax, Chex, Optax e Orbax.
- Para conferir a documentação do tokenizer/detokenizer
sentencepiece
, consulte o repositóriosentencepiece
do Google no GitHub. - Para conferir a documentação de
kagglehub
, acesseREADME.md
no repositório dokagglehub
no GitHub da Kaggle (em inglês). - Saiba como usar modelos Gemma com a Vertex AI do Google Cloud.
- Se você estiver usando TPUs do Google Cloud (v3-8 e versões mais recentes), atualize também para o pacote
jax[tpu]
mais recente (!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
), reinicie o ambiente de execução e verifique se as versõesjax
ejaxlib
correspondem (!pip list | grep jax
). Isso pode impedir oRuntimeError
, que pode surgir devido à incompatibilidade de versãojaxlib
ejax
. Para mais instruções de instalação do JAX, consulte os documentos do JAX (em inglês).