Wyświetl na ai.google.dev | Uruchom w Google Colab | Otwórz w Vertex AI | Wyświetl źródło w GitHubie |
Omówienie
Gemma to rodzina lekkich, nowoczesnych, otwartych modeli językowych (LLM) opracowanych na podstawie badań i technologii Google DeepMind Gemini. Ten samouczek pokazuje, jak dostroić model przetwarzania instruktażowego Gemma 2B} z użyciem biblioteki Google DeepMind} gemma
, JAX (biblioteki obliczeń numerycznych o wysokiej wydajności), Flax (biblioteki sieci neuronowej opartej na JAX), Chex (biblioteki narzędzi do pisania niezawodnego kodu JAX i biblioteki Translation Translatora JAX, Optax Choć w notatniku nie jest używany bezpośrednio w tym notatniku, do utworzenia Gemmy użyto flaxa.
Biblioteka gemma
została napisana przy użyciu języków JAX, Flax, Orbax (oparta na języku JAX biblioteka do narzędzi treningowych, takich jak punkty kontrolne) oraz SentencePiece (biblioteka do tokenizacji i detokenizera).
Konfiguracja
1. Konfigurowanie dostępu do Kaggle dla Gemma
Aby ukończyć ten samouczek, musisz najpierw wykonać instrukcje konfiguracji opisane w artykule Konfiguracja Gemma, z którego dowiesz się, jak:
- Uzyskaj dostęp do Gemmy na kaggle.com.
- Wybierz środowisko wykonawcze Colab z wystarczającą ilością zasobów do uruchomienia modelu Gemma.
- Wygeneruj i skonfiguruj nazwę użytkownika i klucz interfejsu API Kaggle.
Po zakończeniu konfiguracji Gemma przejdź do następnej sekcji, w której możesz ustawić zmienne środowiskowe dla środowiska Colab.
2. Ustawianie zmiennych środowiskowych
Ustaw zmienne środowiskowe dla interfejsów KAGGLE_USERNAME
i KAGGLE_KEY
. Gdy pojawi się komunikat „Przyznać dostęp?”, Użytkownik wyraża zgodę na przyznanie tajnego dostępu.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
3. Zainstaluj bibliotekę gemma
Bezpłatna akceleracja sprzętowa Colab jest obecnie niewystarczająca do uruchomienia tego notatnika. Jeśli korzystasz z Colab Pay As You Go lub Colab Pro, kliknij Edytuj > Ustawienia notatnika > Wybierz GPU A100 > Zapisz, aby włączyć akcelerację sprzętową.
Następnie musisz zainstalować bibliotekę Google DeepMind gemma
ze strony github.com/google-deepmind/gemma
. Jeśli pojawi się błąd dotyczący resolvera zależności pip, zwykle możesz go zignorować.
pip install -q git+https://github.com/google-deepmind/gemma.git
4. Importuj biblioteki
Ten notatnik korzysta z Flax (do obsługi sieci neuronowych), podstawowego kodu JAX, SentencePiece (do tokenizacji), Chex (biblioteki narzędzi do tworzenia niezawodnego kodu JAX) i zbiorów danych TensorFlow.
import os
import enum
import re
import string
import chex
import jax
import jax.numpy as jnp
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm
Wczytaj model Gemma
Wczytaj model Gemma za pomocą parametru kagglehub.model_download
, który przyjmuje 3 argumenty:
handle
: uchwyt modelu z Kagglepath
: (opcjonalny ciąg znaków) ścieżka lokalnaforce_download
: (opcjonalna wartość logiczna) wymusza ponowne pobranie modelu.
GEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub
GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/gemma/flax/2b-it/2/download... 100%|██████████| 3.67G/3.67G [00:26<00:00, 147MB/s] Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2
Sprawdź lokalizację wag modelu i tokenizatora, a następnie ustaw zmienne ścieżki. Katalog tokenizera znajduje się w katalogu głównym, z którego został pobrany model, a wagi modelu – w podkatalogu. Na przykład:
- Plik
tokenizer.model
będzie w lokalizacji/LOCAL/PATH/TO/gemma/flax/2b-it/2
. - Punkt kontrolny modelu będzie w:
/LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it
.
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/2b-it TOKENIZER_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/tokenizer.model
Wczytywanie i przygotowywanie zbioru danych MTNT oraz tokenizatora Gemma
Użyjesz zbioru danych MTNT (Machine Translation of Noisy Text), który jest dostępny ze zbiorów danych TensorFlow.
Pobierz fragment zbioru danych MTNT z języka angielskiego na język francuski i wyświetl 2 przykłady. Każda próbka w zbiorze danych zawiera 2 pozycje: src
– oryginalne zdanie w języku angielskim; i dst
: odpowiednie tłumaczenie w języku francuskim.
ds = tfds.load("mtnt/en-fr", split="train")
ds = ds.take(2)
ds = ds.as_numpy_iterator()
for idx, example in enumerate(ds):
print(f'Example {idx}:')
for key, val in example.items():
print(f'{key}: {val}')
print()
Downloading and preparing dataset 35.08 MiB (download: 35.08 MiB, generated: 11.33 MiB, total: 46.41 MiB) to /root/tensorflow_datasets/mtnt/en-fr/1.0.0... Dl Completed...: 0 url [00:00, ? url/s] Dl Size...: 0 MiB [00:00, ? MiB/s] Extraction completed...: 0 file [00:00, ? file/s] Generating splits...: 0%| | 0/3 [00:00<?, ? splits/s] Generating train examples...: 0%| | 0/35692 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incomplete6YJMND/mtnt-train.tfrecord*...: 0%| … Generating test examples...: 0%| | 0/1020 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incomplete6YJMND/mtnt-test.tfrecord*...: 0%| |… Generating valid examples...: 0%| | 0/811 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incomplete6YJMND/mtnt-valid.tfrecord*...: 0%| … Dataset mtnt downloaded and prepared to /root/tensorflow_datasets/mtnt/en-fr/1.0.0. Subsequent calls will reuse this data. Example 0: dst: b'Le groupe de " toutes les \xc3\xa9toiles potentielles de la conf\xc3\xa9rence de l\'Est mais qui ne s\'en sortent pas dans le groupe de l\'Ouest ".' src: b'The group of \xe2\x80\x9ceastern conference potential all stars but not making it in the West\xe2\x80\x9d group.' Example 1: dst: b"Kameron est-elle un peu aigrie de son manque de temps \xc3\xa0 l'\xc3\xa9cran ?" src: b'Is Kameron a Little Salty About Her Lack of Air Time?'
Wczytaj tokenizer Gemma utworzony za pomocą sentencepiece.SentencePieceProcessor
:
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
Dostosuj SentencePieceProcessor
do zadania tłumaczenia z angielskiego na francuski. Ponieważ będzie dostrajać angielską część modelu Gemma, musisz wprowadzić kilka zmian, na przykład:
Prefiks danych wejściowych: dodanie wspólnego prefiksu do każdego danych wejściowych sygnalizuje zadanie translacji. Możesz na przykład użyć promptu z prefiksem takim jak
Translate this into French: [INPUT_SENTENCE]
.Sufiks początkowy translacji: dodanie sufiksu na końcu każdego promptu zapewni modelowi Gemma dokładne informacje o tym, kiedy rozpocząć proces translacji. Nowy wiersz powinien wykonać zadanie.
Tokeny modeli językowych: modele Gemma oczekują „początku sekwencji”. token na początku każdej sekwencji, więc dodaj „koniec sekwencji” na końcu każdego przykładu trenowania powinno wystarczyć.
Utwórz własny kod wokół elementu
SentencePieceProcessor
w ten sposób:
class GemmaTokenizer:
def __init__(self,
spm_processor: spm.SentencePieceProcessor):
self._spm_processor = spm_processor
@property
def pad_id(self) -> int:
"""Fast access to the pad ID."""
return self._spm_processor.pad_id()
def tokenize(self,
example: str | bytes,
prefix: str = '',
suffix: str = '',
add_eos: bool = True) -> jax.Array:
"""
The tokenization function.
Args:
example: Input string to tokenize.
prefix: Prefix to add to the input string.
suffix: Suffix to add to the input string.
add_eos: If True, add an "end of sentence" token at the end of the output
sequence.
Returns:
Tokens corresponding to the input string.
"""
int_list = [self._spm_processor.bos_id()]
int_list.extend(self._spm_processor.EncodeAsIds(prefix + example + suffix))
if add_eos:
int_list.append(self._spm_processor.eos_id())
return jnp.array(int_list, dtype=jnp.int32)
def tokenize_tf_op(self,
str_tensor: tf.Tensor,
prefix: str = '',
suffix: str = '',
add_eos: bool = True) -> tf.Tensor:
"""A TensorFlow operator for the tokenize function."""
encoded = tf.numpy_function(
self.tokenize,
[str_tensor, prefix, suffix, add_eos],
tf.int32)
encoded.set_shape([None])
return encoded
def to_string(self, tokens: jax.Array) -> str:
"""Convert an array of tokens to a string."""
return self._spm_processor.EncodeIds(tokens.tolist())
Wypróbuj tę funkcję, tworząc nową niestandardową instancję GemmaTokenizer
, a następnie stosując ją do niewielkiej próbki zbioru danych MTNT:
tokenizer = GemmaTokenizer(vocab)
def tokenize_source(tokenizer, example: tf.Tensor):
return tokenizer.tokenize_tf_op(example,
prefix='Translate this into French:\n',
suffix='\n',
add_eos=False)
def tokenize_destination(tokenizer, example: tf.Tensor):
return tokenizer.tokenize_tf_op(example,
add_eos=True)
ds = tfds.load("mtnt/en-fr",split="train")
ds = ds.take(2)
ds = ds.map(lambda x: {'src': tokenize_source(tokenizer, x['src']),
'dst': tokenize_destination(tokenizer, x['dst'])})
ds = ds.as_numpy_iterator()
for idx, example in enumerate(ds):
print(f'Example {idx}:')
for key, val in example.items():
print(f'{key}: {val}')
print()
Example 0: src: [ 2 49688 736 1280 6987 235292 108 651 2778 576 1080 104745 11982 5736 832 8995 901 780 3547 665 575 573 4589 235369 2778 235265 108] dst: [ 2 2025 29653 581 664 16298 1437 55563 41435 7840 581 683 111452 581 533 235303 9776 4108 2459 679 485 235303 479 6728 579 1806 2499 709 29653 581 533 235303 101323 16054 1] Example 1: src: [ 2 49688 736 1280 6987 235292 108 2437 87150 477 476 11709 230461 8045 3636 40268 576 4252 4897 235336 108] dst: [ 2 213606 477 1455 235290 3510 748 8268 191017 2809 581 2032 69972 581 11495 1305 533 235303 65978 1654 1]
Utwórz moduł wczytujący dane dla całego zbioru danych MTNT:
@chex.dataclass(frozen=True)
class TrainingInput:
# Input tokens provided to the model.
input_tokens: jax.Array
# A mask that determines which tokens contribute to the target loss
# calculation.
target_mask: jax.Array
class DatasetSplit(enum.Enum):
TRAIN = 'train'
VALIDATION = 'valid'
class MTNTDatasetBuilder:
"""The dataset builder for the MTNT dataset."""
N_ITEMS = {DatasetSplit.TRAIN: 35_692,
DatasetSplit.VALIDATION: 811}
BUFFER_SIZE_SHUFFLE = 10_000
TRANSLATION_PREFIX = 'Translate this into French:\n'
TRANSLATION_SUFFIX = '\n'
def __init__(self,
tokenizer : GemmaTokenizer,
max_seq_len: int):
"""Constructor.
Args:
tokenizer: Gemma tokenizer to use.
max_seq_len: size of each sequence in a given batch.
"""
self._tokenizer = tokenizer
self._base_data = {
DatasetSplit.TRAIN: tfds.load("mtnt/en-fr",split="train"),
DatasetSplit.VALIDATION: tfds.load("mtnt/en-fr",split="valid"),
}
self._max_seq_len = max_seq_len
def _tokenize_source(self, example: tf.Tensor):
"""Tokenization function for the source."""
return self._tokenizer.tokenize_tf_op(example,
prefix=self.TRANSLATION_PREFIX,
suffix=self.TRANSLATION_SUFFIX,
add_eos=False)
def _tokenize_destination(self, example: tf.Tensor):
"""Tokenization function for the French translation."""
return self._tokenizer.tokenize_tf_op(example,
add_eos=True)
def _pad_up_to_max_len(self,
input_tensor: tf.Tensor,
pad_value: int | bool,
) -> tf.Tensor:
"""Pad the given tensor up to sequence length of a batch."""
seq_len = tf.shape(input_tensor)[0]
to_pad = tf.maximum(self._max_seq_len - seq_len, 0)
return tf.pad(input_tensor,
[[0, to_pad]],
mode='CONSTANT',
constant_values=pad_value,
)
def _to_training_input(self,
src_tokens: jax.Array,
dst_tokens: jax.Array,
) -> TrainingInput:
"""Build a training input from a tuple of source and destination tokens."""
# The input sequence fed to the model is simply the concatenation of the
# source and the destination.
tokens = tf.concat([src_tokens, dst_tokens], axis=0)
# To prevent the model from updating based on the source (input)
# tokens, add a target mask to each input.
q_mask = tf.zeros_like(src_tokens, dtype=tf.bool)
a_mask = tf.ones_like(dst_tokens, dtype=tf.bool)
mask = tf.concat([q_mask, a_mask], axis=0)
# If the output tokens sequence is smaller than the target sequence size,
# then pad it with pad tokens.
tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)
# Don't want to perform the backward pass on the pad tokens.
mask = self._pad_up_to_max_len(mask, False)
return TrainingInput(input_tokens=tokens, target_mask=mask)
def get_train_dataset(self, batch_size: int, num_epochs: int):
"""Build the training dataset."""
# Tokenize each sample.
ds = self._base_data[DatasetSplit.TRAIN].map(lambda x : (self._tokenize_source(x['src']),
self._tokenize_destination(x['dst'])))
# Convert the samples to training inputs.
ds = ds.map(lambda x, y: self._to_training_input(x, y))
# Remove the samples that are too long.
ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
# Shuffle the dataset.
ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)
# Repeat if necessary.
ds = ds.repeat(num_epochs)
# Build batches.
ds = ds.batch(batch_size, drop_remainder=True)
return ds
def get_validation_dataset(self, batch_size: int):
"""Build the validation dataset."""
# Same steps as in `get_train_dataset`, but without shuffling and no repetition.
ds = self._base_data[DatasetSplit.VALIDATION].map(lambda x : (self._tokenize_source(x['src']),
self._tokenize_destination(x['dst'])))
ds = ds.map(lambda x, y: self._to_training_input(x, y))
ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
ds = ds.batch(batch_size, drop_remainder=True)
return ds
Wypróbuj funkcję MTNTDatasetBuilder
, ponownie tworząc niestandardową instancję GemmaTokenizer
, a następnie stosując ją do zbioru danych MTNT i pobierając próbkowanie 2 przykładów:
tokenizer = GemmaTokenizer(vocab)
dataset_builder = MTNTDatasetBuilder(tokenizer, max_seq_len=20)
ds = dataset_builder.get_train_dataset(3, 1)
ds = ds.take(2)
ds = ds.as_numpy_iterator()
for idx, example in enumerate(ds):
print(f'Example {idx}:')
for key, val in example.items():
print(f'{key}: {val}')
print()
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'> WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'> WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'> Example 0: input_tokens: [[ 2 49688 736 1280 6987 235292 108 10924 665 12302 235341 108 2 4397 63011 1437 38696 1241 1 0] [ 2 49688 736 1280 6987 235292 108 13835 1517 235265 108 2 69875 540 19713 235265 1 0 0 0] [ 2 49688 736 1280 6987 235292 108 6956 1586 235297 235265 108 2 78368 1586 235297 235265 1 0 0]] target_mask: [[False False False False False False False False False False False False True True True True True True True False] [False False False False False False False False False False False True True True True True True False False False] [False False False False False False False False False False False False True True True True True True False False]] Example 1: input_tokens: [[ 2 49688 736 1280 6987 235292 108 18874 235341 108 2 115905 6425 1241 1 0 0 0 0 0] [ 2 49688 736 1280 6987 235292 108 7574 3356 235341 108 2 7997 20707 1241 1 0 0 0 0] [ 2 49688 736 1280 6987 235292 108 8703 665 235265 108 2 235338 235303 90006 20133 235265 1 0 0]] target_mask: [[False False False False False False False False False False True True True True True False False False False False] [False False False False False False False False False False False True True True True True False False False False] [False False False False False False False False False False False True True True True True True True False False]]
Konfigurowanie modelu
Zanim zaczniesz dostrajać model Gemma, musisz go skonfigurować.
Najpierw wczytaj i sformatuj punkt kontrolny modelu Gemma za pomocą metody gemma.params.load_and_format_params
:
params = params_lib.load_and_format_params(CKPT_PATH)
Aby automatycznie wczytywać prawidłową konfigurację z punktu kontrolnego modelu Gemma, użyj narzędzia gemma.transformer.TransformerConfig
. Argument cache_size
to liczba kroków czasu w pamięci podręcznej aplikacji Gemma Transformer
. Następnie utwórz instancję modelu Gemma jako model_2b
z użyciem gemma.transformer.Transformer
(dziedziczącego z flax.linen.Module
).
config_2b = transformer_lib.TransformerConfig.from_params(
params,
cache_size=30
)
model_2b = transformer_lib.Transformer(config=config_2b)
Dostrój model
W tej sekcji:
- Użyj klasy
gemma.transformer.Transformer
, aby utworzyć funkcję przekazywania dalej i utraty. - Utwórz wektory maski pozycji i uwagi dla tokenów
- Utwórz funkcję kroku trenowania za pomocą narzędzia Flax.
- Utwórz krok weryfikacji bez przebiegu wstecznego.
- Utwórz pętlę trenowania.
- Dostrój model Gemma.
Zdefiniuj klucz przekazywania i funkcję utraty za pomocą klasy gemma.transformer.Transformer
. Pole Gemma Transformer
dziedziczy dane z metody flax.linen.Module
i udostępnia 2 podstawowe metody:
init
: inicjuje parametry modelu.apply
: wykonuje funkcję__call__
modelu, korzystając z podanego zbioru parametrów.Ponieważ używasz już wytrenowanych wag Gemma, nie musisz korzystać z funkcji
init
.
def forward_and_loss_fn(params,
*,
model: transformer_lib.Transformer,
input_tokens: jax.Array, # Shape [B, L]
input_mask: jax.Array, # Shape [B, L]
positions: jax.Array, # Shape [B, L]
attention_mask: jax.Array, # [B, L, L]
) -> jax.Array:
"""The forward pass and the loss function.
Args:
params: Model's input parameters.
model: The Gemma transformer model to call.
input_tokens: Input tokens sequence, shape [B, L].
input_mask: Tokens to ignore when computing the loss, shape [B, L].
positions: Relative position of each token, shape [B, L].
attention_mask: Input attention mask, shape [B, L].
Returns:
The softmax cross-entropy loss for the next-token prediction task.
"""
# The forward pass on the input data.
# No attention cache is needed here.
logits, _ = model.apply(
params,
input_tokens,
positions,
None, # Attention cache is None.
attention_mask,
)
# Exclude the last step as it does not appear in the targets.
logits = logits[0, :-1]
# Similarly, the first token cannot be predicted.
target_tokens = input_tokens[0, 1:]
target_mask = input_mask[0, 1:]
# Convert the target labels to one-hot encoded vectors.
one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])
# Don't update on unwanted tokens.
one_hot = one_hot * target_mask.astype(one_hot.dtype)[...,None]
# Define the normalization factor.
norm_factor = 1 / (jnp.sum(target_mask) + 1e-8)
# Return the negative log likelihood (NLL) loss.
return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) * norm_factor
Klasa gemma.transformer.Transformer
wymaga wektorów attention_mask
i positions
przy każdym wektorze wejściowych. Możesz je wygenerować, tworząc funkcję niestandardową korzystającą z Transformer.build_positions_from_mask
i Transformer.make_causal_attn_mask
:
def get_attention_mask_and_positions(example: jax.Array,
pad_id : int,
)-> tuple[jax.Array, jax.Array]:
"""Builds the position and attention mask vectors from the given tokens."""
pad_mask = example != pad_id
current_token_position = transformer_lib.build_positions_from_mask(pad_mask)
attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)
return current_token_position, attention_mask
Utwórz funkcję train_step
, która wykonuje przekazywanie wsteczne i odpowiednio aktualizuje parametry modelu, gdzie:
- Narzędzie
jax.value_and_grad
służy do oceny funkcji straty i gradientów podczas przechodzenia do przodu i do tyłu. optax.apply_updates
służy do aktualizowania parametrów.
def train_step(model: transformer_lib.Transformer,
params,
optimizer: optax.GradientTransformation,
opt_state: optax.OptState,
pad_id: int,
example: TrainingInput):
"""Train step.
Args:
model: The Gemma transformer model.
params: The model's input parameters.
optimizer: The Optax optimizer to use.
opt_state: The input optimizer's state.
pad_id: ID of the pad token.
example: Input batch.
Returns:
The training loss, the updated parameters, and the updated optimizer state.
"""
# Build the position and attention mask vectors.
positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)
# The forward and backward passes.
train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(params,
model=model,
input_tokens=example.input_tokens,
input_mask=example.target_mask,
positions=positions,
attention_mask=attention_mask)
# Update the parameters.
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
return train_loss, params, opt_state
Utwórz funkcję validation_step
bez przekierowania wstecznego:
def validation_step(model: transformer_lib.Transformer,
params,
pad_id: int,
example: TrainingInput,
):
positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)
val_loss = forward_and_loss_fn(params,
model=model,
input_tokens=example.input_tokens,
input_mask=example.target_mask,
positions=positions,
attention_mask=attention_mask)
return val_loss
Zdefiniuj pętlę trenowania za pomocą optax.sgd
w przypadku optymalizatora SGD:
@chex.dataclass(frozen=True)
class TrainingConfig:
learning_rate: float
num_epochs: int
eval_every_n: int
batch_size: int
max_steps: int | None = None
def train_loop(
model: transformer_lib.Transformer,
params,
dataset_builder: MTNTDatasetBuilder,
training_cfg: TrainingConfig):
# Apply `jax.jit` on the training step, making the whole loop much more efficient.
compiled_train_step = jax.jit(train_step, static_argnames=['model', 'optimizer'])
# Apply `jax.jit` on the validation step.
compiled_validation_step = jax.jit(validation_step, static_argnames=['model'])
# To save memory, use the SGD optimizer instead of the usual Adam optimizer.
# Note that for this specific example, SGD is more than enough.
optimizer = optax.sgd(training_cfg.learning_rate)
opt_state = optimizer.init(params)
# Build the training dataset.
train_ds = dataset_builder.get_train_dataset(batch_size=training_cfg.batch_size,
num_epochs=training_cfg.num_epochs)
train_ds = train_ds.as_numpy_iterator()
# Build the validation dataset, with a limited number of samples for this demo.
validation_ds = dataset_builder.get_validation_dataset(batch_size=training_cfg.batch_size)
validation_ds = validation_ds.take(50)
n_steps = 0
avg_loss=0
# A first round of the validation loss.
n_steps_eval = 0
eval_loss = 0
val_iterator = validation_ds.as_numpy_iterator()
for val_example in val_iterator:
eval_loss += compiled_validation_step(model,
params,
dataset_builder._tokenizer.pad_id,
val_example)
n_steps_eval += 1
print(f"Start, validation loss: {eval_loss/n_steps_eval}")
for train_example in train_ds:
train_loss, params, opt_state = compiled_train_step(model=model,
params=params,
optimizer=optimizer,
opt_state=opt_state,
pad_id=dataset_builder._tokenizer.pad_id,
example=train_example)
n_steps += 1
avg_loss += train_loss
if n_steps % training_cfg.eval_every_n == 0:
eval_loss = 0
n_steps_eval = 0
val_iterator = validation_ds.as_numpy_iterator()
for val_example in val_iterator:
eval_loss += compiled_validation_step(model,
params,
dataset_builder._tokenizer.pad_id,
val_example)
n_steps_eval +=1
avg_loss /= training_cfg.eval_every_n
eval_loss /= n_steps_eval
print(f"STEP {n_steps} training loss: {avg_loss} - eval loss: {eval_loss}")
avg_loss=0
if training_cfg.max_steps is not None and n_steps > training_cfg.max_steps:
break
return params
Rozpocznij dostrajanie modelu Gemma w ograniczonej liczbie kroków (SEQ_SIZE
), aby mieć pewność, że mieści się w pamięci:
SEQ_SIZE = 25
tokenizer = GemmaTokenizer(vocab)
dataset_builder= MTNTDatasetBuilder(tokenizer, SEQ_SIZE)
training_cfg = TrainingConfig(learning_rate=1e-4,
num_epochs=1,
eval_every_n=20,
batch_size=1,
max_steps=100)
params = train_loop(model=model_2b,
params={'params': params['transformer']},
dataset_builder=dataset_builder,
training_cfg=training_cfg)
Start, validation loss: 10.647212982177734 STEP 20 training loss: 3.3015992641448975 - eval loss: 2.686880111694336 STEP 40 training loss: 5.375057220458984 - eval loss: 2.6751961708068848 STEP 60 training loss: 2.6599338054656982 - eval loss: 2.663877010345459 STEP 80 training loss: 4.822389125823975 - eval loss: 2.3333375453948975 STEP 100 training loss: 2.0131142139434814 - eval loss: 2.360811948776245
Zarówno utrata trenowania, jak i utrata walidacji powinny zmaleć z każdą liczbą kroków.
Utwórz sampler
w gemma.sampler.Sampler
. Wykorzystuje punkt kontrolny modelu Gemma i tokenizer.
sampler = sampler_lib.Sampler(
transformer=model_2b,
vocab=vocab,
params=params['params'],
)
Użyj sampler
, aby sprawdzić, czy model może wykonywać tłumaczenie. Argument total_generation_steps
w narzędziu gemma.sampler.Sampler
to liczba kroków wykonanych podczas generowania odpowiedzi. Aby dane wejściowe były zgodne z formatem trenowania, użyj prefiksu Translate this into French:\n
ze znakiem nowego wiersza na końcu. To sygnalizuje modelowi rozpoczęcie translacji.
sampler(
["Translate this into French:\nHello, my name is Morgane.\n"],
total_generation_steps=100,
).text
["C'est Bonjour, mon nom est Morgane.C'est Bonjour, mon nom est Morgane."]
Więcej informacji
- Więcej informacji o bibliotece Google DeepMind
gemma
znajdziesz na GitHubie, która zawiera ciągi dokumentów z modułami użytymi w tym samouczku, takie jakgemma.params
,gemma.transformer
orazgemma.sampler
. - Te biblioteki mają własne witryny z dokumentacją: core JAX, Flax, Chex, Optax i Orbax.
- Dokumentację tokenizacji i detokenizatora usługi
sentencepiece
znajdziesz w repozytorium Google na GitHubiesentencepiece
. - Dokumentację usługi
kagglehub
znajdziesz w witrynieREADME.md
w repozytorium GitHubkagglehub
firmy Kaggle. - Dowiedz się, jak używać modeli Gemma w Vertex AI Google Cloud.
- Jeśli używasz jednostek Google Cloud TPU (wersja 3-8 lub nowsza), zaktualizuj też pakiet
jax[tpu]
do najnowszej wersji (!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
), uruchom ponownie środowisko wykonawcze i sprawdź, czy wersjejax
ijaxlib
są zgodne (!pip list | grep jax
). Może to zapobiec powstawaniu błędów typuRuntimeError
z powodu niezgodności wersjijaxlib
ijax
. Więcej instrukcji instalacji języka JAX znajdziesz w dokumentacji JAX.