前往 ai.google.dev 查看 | 在 Google Colab 中執行 | 在 Vertex AI 中開啟 | 前往 GitHub 查看原始碼 |
總覽
Gemma 是一組最先進的開放式大型語言模型,以 Google DeepMind Gemini 的研究和技術為基礎。本教學課程說明如何使用 Google DeepMind 的 gemma
程式庫、JAX (高效能數值運算程式庫)、Flax (JAX 型類神經網路程式庫)、Chex (用於編寫可靠 JAX 程式碼的 JAX 程式碼程式庫) 微調 Gemma 2B Instruct 模型中的 Gemma 2B Instruct 模型;Optax雖然這個筆記本並未直接使用 Flax,但要使用 Flax 建立 Gemma。
gemma
程式庫是使用 JAX、Flax、Orbax (用於檢查點等訓練公用程式的 JAX 程式庫) 和 SentencePiece (符記化工具/去符記化程式庫) 編寫。
設定
1. 設定 Gemma 的 Kaggle 存取權
為完成本教學課程,您必須先按照 Gemma 設定中的設定說明操作,其中將說明如何執行下列操作:
- 前往 kaggle.com 存取 Gemma。
- 選取具備足夠資源來執行 Gemma 模型的 Colab 執行階段。
- 產生並設定 Kaggle 使用者名稱和 API 金鑰。
完成 Gemma 設定後,請繼續前往下一節,設定 Colab 環境的環境變數。
2. 設定環境變數
設定 KAGGLE_USERNAME
和 KAGGLE_KEY
的環境變數。系統顯示「授予存取權?」提示訊息時,訊息,同意提供密鑰存取權。
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
3. 安裝 gemma
程式庫
目前免費的 Colab 硬體加速功能不足,無法執行這個筆記本。如果使用 Colab Pay As You Go 或 Colab Pro,請按一下「編輯」>筆記本設定 >依序選取「A100 GPU」>按一下「儲存」即可啟用硬體加速功能。
接下來,您需要從 github.com/google-deepmind/gemma
安裝 Google DeepMind gemma
程式庫。如果收到有關「pip 依附元件解析器」的錯誤,通常可以忽略。
pip install -q git+https://github.com/google-deepmind/gemma.git
4. 匯入程式庫
這個筆記本使用 Flax (適用於類神經網路)、核心 JAX、SentencePiece (用於權杖化)、Chex (編寫可靠 JAX 程式碼的公用程式程式庫),以及 TensorFlow Datasets。
import os
import enum
import re
import string
import chex
import jax
import jax.numpy as jnp
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm
載入 Gemma 模型
使用 kagglehub.model_download
載入 Gemma 模型,該模型會使用三個引數:
handle
:Kaggle 的模型控制代碼path
:(選用字串) 本機路徑force_download
:(選用布林值) 強制重新下載模型
GEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub
GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/gemma/flax/2b-it/2/download... 100%|██████████| 3.67G/3.67G [00:26<00:00, 147MB/s] Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2
檢查模型權重的位置和符記化工具,然後設定路徑變數。符記化工具目錄會存放您下載模型的主要目錄,而模型權重則位於子目錄。例如:
tokenizer.model
檔案會在/LOCAL/PATH/TO/gemma/flax/2b-it/2
中)。- 模型查核點位於
/LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it
)。
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/2b-it TOKENIZER_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/tokenizer.model
載入並準備 MTNT 資料集和 Gemma 權杖化工具
您將使用 MTNT (Noisy Text 機器翻譯) 資料集,您可以透過 TensorFlow 資料集取得該資料集。
下載 MTNT 資料集的「English-to-French」資料集部分,然後取樣兩個範例。資料集中的每個樣本都包含兩個項目:src
:原始英文語句;和 dst
:對應的法文翻譯。
ds = tfds.load("mtnt/en-fr", split="train")
ds = ds.take(2)
ds = ds.as_numpy_iterator()
for idx, example in enumerate(ds):
print(f'Example {idx}:')
for key, val in example.items():
print(f'{key}: {val}')
print()
Downloading and preparing dataset 35.08 MiB (download: 35.08 MiB, generated: 11.33 MiB, total: 46.41 MiB) to /root/tensorflow_datasets/mtnt/en-fr/1.0.0... Dl Completed...: 0 url [00:00, ? url/s] Dl Size...: 0 MiB [00:00, ? MiB/s] Extraction completed...: 0 file [00:00, ? file/s] Generating splits...: 0%| | 0/3 [00:00<?, ? splits/s] Generating train examples...: 0%| | 0/35692 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incomplete6YJMND/mtnt-train.tfrecord*...: 0%| … Generating test examples...: 0%| | 0/1020 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incomplete6YJMND/mtnt-test.tfrecord*...: 0%| |… Generating valid examples...: 0%| | 0/811 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incomplete6YJMND/mtnt-valid.tfrecord*...: 0%| … Dataset mtnt downloaded and prepared to /root/tensorflow_datasets/mtnt/en-fr/1.0.0. Subsequent calls will reuse this data. Example 0: dst: b'Le groupe de " toutes les \xc3\xa9toiles potentielles de la conf\xc3\xa9rence de l\'Est mais qui ne s\'en sortent pas dans le groupe de l\'Ouest ".' src: b'The group of \xe2\x80\x9ceastern conference potential all stars but not making it in the West\xe2\x80\x9d group.' Example 1: dst: b"Kameron est-elle un peu aigrie de son manque de temps \xc3\xa0 l'\xc3\xa9cran ?" src: b'Is Kameron a Little Salty About Her Lack of Air Time?'
載入使用 sentencepiece.SentencePieceProcessor
建構的 Gemma 權杖化工具:
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
自訂 SentencePieceProcessor
用於英文到法文的翻譯工作。由於系統會微調 Gemma 模型的英文部分,因此您必須進行一些調整,例如:
輸入前置字元:在每個輸入值中加入共同的前置字元,代表翻譯工作。舉例來說,您可以在提示中使用
Translate this into French: [INPUT_SENTENCE]
等前置字串。翻譯開始後置字串:在每個提示結尾加上後置字串,可指示 Gemma 模型開始轉譯程序的確切時間。這樣工作應該就會有一行。
語言模型符記:Gemma 模型預期「序列開頭」因此會在每個序列開頭加入一個「序列結尾」每個訓練範例結尾的符記應已足夠
在
SentencePieceProcessor
周圍建構自訂包裝函式,如下所示:
class GemmaTokenizer:
def __init__(self,
spm_processor: spm.SentencePieceProcessor):
self._spm_processor = spm_processor
@property
def pad_id(self) -> int:
"""Fast access to the pad ID."""
return self._spm_processor.pad_id()
def tokenize(self,
example: str | bytes,
prefix: str = '',
suffix: str = '',
add_eos: bool = True) -> jax.Array:
"""
The tokenization function.
Args:
example: Input string to tokenize.
prefix: Prefix to add to the input string.
suffix: Suffix to add to the input string.
add_eos: If True, add an "end of sentence" token at the end of the output
sequence.
Returns:
Tokens corresponding to the input string.
"""
int_list = [self._spm_processor.bos_id()]
int_list.extend(self._spm_processor.EncodeAsIds(prefix + example + suffix))
if add_eos:
int_list.append(self._spm_processor.eos_id())
return jnp.array(int_list, dtype=jnp.int32)
def tokenize_tf_op(self,
str_tensor: tf.Tensor,
prefix: str = '',
suffix: str = '',
add_eos: bool = True) -> tf.Tensor:
"""A TensorFlow operator for the tokenize function."""
encoded = tf.numpy_function(
self.tokenize,
[str_tensor, prefix, suffix, add_eos],
tf.int32)
encoded.set_shape([None])
return encoded
def to_string(self, tokens: jax.Array) -> str:
"""Convert an array of tokens to a string."""
return self._spm_processor.EncodeIds(tokens.tolist())
如要試用這項功能,請將新的自訂 GemmaTokenizer
執行個體化,然後套用至 MTNT 資料集的小樣本:
tokenizer = GemmaTokenizer(vocab)
def tokenize_source(tokenizer, example: tf.Tensor):
return tokenizer.tokenize_tf_op(example,
prefix='Translate this into French:\n',
suffix='\n',
add_eos=False)
def tokenize_destination(tokenizer, example: tf.Tensor):
return tokenizer.tokenize_tf_op(example,
add_eos=True)
ds = tfds.load("mtnt/en-fr",split="train")
ds = ds.take(2)
ds = ds.map(lambda x: {'src': tokenize_source(tokenizer, x['src']),
'dst': tokenize_destination(tokenizer, x['dst'])})
ds = ds.as_numpy_iterator()
for idx, example in enumerate(ds):
print(f'Example {idx}:')
for key, val in example.items():
print(f'{key}: {val}')
print()
Example 0: src: [ 2 49688 736 1280 6987 235292 108 651 2778 576 1080 104745 11982 5736 832 8995 901 780 3547 665 575 573 4589 235369 2778 235265 108] dst: [ 2 2025 29653 581 664 16298 1437 55563 41435 7840 581 683 111452 581 533 235303 9776 4108 2459 679 485 235303 479 6728 579 1806 2499 709 29653 581 533 235303 101323 16054 1] Example 1: src: [ 2 49688 736 1280 6987 235292 108 2437 87150 477 476 11709 230461 8045 3636 40268 576 4252 4897 235336 108] dst: [ 2 213606 477 1455 235290 3510 748 8268 191017 2809 581 2032 69972 581 11495 1305 533 235303 65978 1654 1]
建構整個 MTNT 資料集的資料載入器:
@chex.dataclass(frozen=True)
class TrainingInput:
# Input tokens provided to the model.
input_tokens: jax.Array
# A mask that determines which tokens contribute to the target loss
# calculation.
target_mask: jax.Array
class DatasetSplit(enum.Enum):
TRAIN = 'train'
VALIDATION = 'valid'
class MTNTDatasetBuilder:
"""The dataset builder for the MTNT dataset."""
N_ITEMS = {DatasetSplit.TRAIN: 35_692,
DatasetSplit.VALIDATION: 811}
BUFFER_SIZE_SHUFFLE = 10_000
TRANSLATION_PREFIX = 'Translate this into French:\n'
TRANSLATION_SUFFIX = '\n'
def __init__(self,
tokenizer : GemmaTokenizer,
max_seq_len: int):
"""Constructor.
Args:
tokenizer: Gemma tokenizer to use.
max_seq_len: size of each sequence in a given batch.
"""
self._tokenizer = tokenizer
self._base_data = {
DatasetSplit.TRAIN: tfds.load("mtnt/en-fr",split="train"),
DatasetSplit.VALIDATION: tfds.load("mtnt/en-fr",split="valid"),
}
self._max_seq_len = max_seq_len
def _tokenize_source(self, example: tf.Tensor):
"""Tokenization function for the source."""
return self._tokenizer.tokenize_tf_op(example,
prefix=self.TRANSLATION_PREFIX,
suffix=self.TRANSLATION_SUFFIX,
add_eos=False)
def _tokenize_destination(self, example: tf.Tensor):
"""Tokenization function for the French translation."""
return self._tokenizer.tokenize_tf_op(example,
add_eos=True)
def _pad_up_to_max_len(self,
input_tensor: tf.Tensor,
pad_value: int | bool,
) -> tf.Tensor:
"""Pad the given tensor up to sequence length of a batch."""
seq_len = tf.shape(input_tensor)[0]
to_pad = tf.maximum(self._max_seq_len - seq_len, 0)
return tf.pad(input_tensor,
[[0, to_pad]],
mode='CONSTANT',
constant_values=pad_value,
)
def _to_training_input(self,
src_tokens: jax.Array,
dst_tokens: jax.Array,
) -> TrainingInput:
"""Build a training input from a tuple of source and destination tokens."""
# The input sequence fed to the model is simply the concatenation of the
# source and the destination.
tokens = tf.concat([src_tokens, dst_tokens], axis=0)
# To prevent the model from updating based on the source (input)
# tokens, add a target mask to each input.
q_mask = tf.zeros_like(src_tokens, dtype=tf.bool)
a_mask = tf.ones_like(dst_tokens, dtype=tf.bool)
mask = tf.concat([q_mask, a_mask], axis=0)
# If the output tokens sequence is smaller than the target sequence size,
# then pad it with pad tokens.
tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)
# Don't want to perform the backward pass on the pad tokens.
mask = self._pad_up_to_max_len(mask, False)
return TrainingInput(input_tokens=tokens, target_mask=mask)
def get_train_dataset(self, batch_size: int, num_epochs: int):
"""Build the training dataset."""
# Tokenize each sample.
ds = self._base_data[DatasetSplit.TRAIN].map(lambda x : (self._tokenize_source(x['src']),
self._tokenize_destination(x['dst'])))
# Convert the samples to training inputs.
ds = ds.map(lambda x, y: self._to_training_input(x, y))
# Remove the samples that are too long.
ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
# Shuffle the dataset.
ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)
# Repeat if necessary.
ds = ds.repeat(num_epochs)
# Build batches.
ds = ds.batch(batch_size, drop_remainder=True)
return ds
def get_validation_dataset(self, batch_size: int):
"""Build the validation dataset."""
# Same steps as in `get_train_dataset`, but without shuffling and no repetition.
ds = self._base_data[DatasetSplit.VALIDATION].map(lambda x : (self._tokenize_source(x['src']),
self._tokenize_destination(x['dst'])))
ds = ds.map(lambda x, y: self._to_training_input(x, y))
ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
ds = ds.batch(batch_size, drop_remainder=True)
return ds
再次將自訂 GemmaTokenizer
執行個體化,然後套用至 MTNT 資料集,然後取樣兩個範例,以試用 MTNTDatasetBuilder
:
tokenizer = GemmaTokenizer(vocab)
dataset_builder = MTNTDatasetBuilder(tokenizer, max_seq_len=20)
ds = dataset_builder.get_train_dataset(3, 1)
ds = ds.take(2)
ds = ds.as_numpy_iterator()
for idx, example in enumerate(ds):
print(f'Example {idx}:')
for key, val in example.items():
print(f'{key}: {val}')
print()
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'> WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'> WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'> Example 0: input_tokens: [[ 2 49688 736 1280 6987 235292 108 10924 665 12302 235341 108 2 4397 63011 1437 38696 1241 1 0] [ 2 49688 736 1280 6987 235292 108 13835 1517 235265 108 2 69875 540 19713 235265 1 0 0 0] [ 2 49688 736 1280 6987 235292 108 6956 1586 235297 235265 108 2 78368 1586 235297 235265 1 0 0]] target_mask: [[False False False False False False False False False False False False True True True True True True True False] [False False False False False False False False False False False True True True True True True False False False] [False False False False False False False False False False False False True True True True True True False False]] Example 1: input_tokens: [[ 2 49688 736 1280 6987 235292 108 18874 235341 108 2 115905 6425 1241 1 0 0 0 0 0] [ 2 49688 736 1280 6987 235292 108 7574 3356 235341 108 2 7997 20707 1241 1 0 0 0 0] [ 2 49688 736 1280 6987 235292 108 8703 665 235265 108 2 235338 235303 90006 20133 235265 1 0 0]] target_mask: [[False False False False False False False False False False True True True True True False False False False False] [False False False False False False False False False False False True True True True True False False False False] [False False False False False False False False False False False True True True True True True True False False]]
設定模型
您必須先設定 Gemma 模型,才能開始微調。
首先,使用 gemma.params.load_and_format_params
方法載入 Gemma 模型查核點並設定格式:
params = params_lib.load_and_format_params(CKPT_PATH)
如要自動從 Gemma 模型查核點載入正確的設定,請使用 gemma.transformer.TransformerConfig
。cache_size
引數是 Gemma Transformer
快取中的時步數。之後,請使用 gemma.transformer.Transformer
(繼承自 flax.linen.Module
) 將 Gemma 模型例項化為 model_2b
。
config_2b = transformer_lib.TransformerConfig.from_params(
params,
cache_size=30
)
model_2b = transformer_lib.Transformer(config=config_2b)
微調模型
在本節中,您將進行以下作業:
- 使用
gemma.transformer.Transformer
類別建立正向傳遞和損失函式。 - 建立符記的位置和注意力遮罩向量
- 使用 Flax 建構訓練步數函式。
- 建立驗證步驟而不使用反向傳遞。
- 建立訓練迴圈。
- 微調 Gemma 模型。
使用 gemma.transformer.Transformer
類別定義正向傳遞和損失函式。Gemma Transformer
繼承自 flax.linen.Module
,並提供兩種基本方法:
init
:初始化模型的參數。apply
:使用指定的一組參數執行模型的__call__
函式。由於您使用的是預先訓練的 Gemma 權重,因此不需要使用
init
函式。
def forward_and_loss_fn(params,
*,
model: transformer_lib.Transformer,
input_tokens: jax.Array, # Shape [B, L]
input_mask: jax.Array, # Shape [B, L]
positions: jax.Array, # Shape [B, L]
attention_mask: jax.Array, # [B, L, L]
) -> jax.Array:
"""The forward pass and the loss function.
Args:
params: Model's input parameters.
model: The Gemma transformer model to call.
input_tokens: Input tokens sequence, shape [B, L].
input_mask: Tokens to ignore when computing the loss, shape [B, L].
positions: Relative position of each token, shape [B, L].
attention_mask: Input attention mask, shape [B, L].
Returns:
The softmax cross-entropy loss for the next-token prediction task.
"""
# The forward pass on the input data.
# No attention cache is needed here.
logits, _ = model.apply(
params,
input_tokens,
positions,
None, # Attention cache is None.
attention_mask,
)
# Exclude the last step as it does not appear in the targets.
logits = logits[0, :-1]
# Similarly, the first token cannot be predicted.
target_tokens = input_tokens[0, 1:]
target_mask = input_mask[0, 1:]
# Convert the target labels to one-hot encoded vectors.
one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])
# Don't update on unwanted tokens.
one_hot = one_hot * target_mask.astype(one_hot.dtype)[...,None]
# Define the normalization factor.
norm_factor = 1 / (jnp.sum(target_mask) + 1e-8)
# Return the negative log likelihood (NLL) loss.
return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) * norm_factor
gemma.transformer.Transformer
類別需要在每項輸入內容旁邊使用 attention_mask
和 positions
向量。您可以建立使用 Transformer.build_positions_from_mask
和 Transformer.make_causal_attn_mask
的自訂函式來產生以下結果:
def get_attention_mask_and_positions(example: jax.Array,
pad_id : int,
)-> tuple[jax.Array, jax.Array]:
"""Builds the position and attention mask vectors from the given tokens."""
pad_mask = example != pad_id
current_token_position = transformer_lib.build_positions_from_mask(pad_mask)
attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)
return current_token_position, attention_mask
建構 train_step
函式,執行反向傳遞,並據此更新模型的參數,其中:
jax.value_and_grad
用於評估向前和向後傳遞的損失函式和梯度。optax.apply_updates
代表要更新參數。
def train_step(model: transformer_lib.Transformer,
params,
optimizer: optax.GradientTransformation,
opt_state: optax.OptState,
pad_id: int,
example: TrainingInput):
"""Train step.
Args:
model: The Gemma transformer model.
params: The model's input parameters.
optimizer: The Optax optimizer to use.
opt_state: The input optimizer's state.
pad_id: ID of the pad token.
example: Input batch.
Returns:
The training loss, the updated parameters, and the updated optimizer state.
"""
# Build the position and attention mask vectors.
positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)
# The forward and backward passes.
train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(params,
model=model,
input_tokens=example.input_tokens,
input_mask=example.target_mask,
positions=positions,
attention_mask=attention_mask)
# Update the parameters.
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
return train_loss, params, opt_state
建構 validation_step
函式,但不使用反向傳遞:
def validation_step(model: transformer_lib.Transformer,
params,
pad_id: int,
example: TrainingInput,
):
positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)
val_loss = forward_and_loss_fn(params,
model=model,
input_tokens=example.input_tokens,
input_mask=example.target_mask,
positions=positions,
attention_mask=attention_mask)
return val_loss
使用 SGD 最佳化工具的 optax.sgd
定義訓練迴圈:
@chex.dataclass(frozen=True)
class TrainingConfig:
learning_rate: float
num_epochs: int
eval_every_n: int
batch_size: int
max_steps: int | None = None
def train_loop(
model: transformer_lib.Transformer,
params,
dataset_builder: MTNTDatasetBuilder,
training_cfg: TrainingConfig):
# Apply `jax.jit` on the training step, making the whole loop much more efficient.
compiled_train_step = jax.jit(train_step, static_argnames=['model', 'optimizer'])
# Apply `jax.jit` on the validation step.
compiled_validation_step = jax.jit(validation_step, static_argnames=['model'])
# To save memory, use the SGD optimizer instead of the usual Adam optimizer.
# Note that for this specific example, SGD is more than enough.
optimizer = optax.sgd(training_cfg.learning_rate)
opt_state = optimizer.init(params)
# Build the training dataset.
train_ds = dataset_builder.get_train_dataset(batch_size=training_cfg.batch_size,
num_epochs=training_cfg.num_epochs)
train_ds = train_ds.as_numpy_iterator()
# Build the validation dataset, with a limited number of samples for this demo.
validation_ds = dataset_builder.get_validation_dataset(batch_size=training_cfg.batch_size)
validation_ds = validation_ds.take(50)
n_steps = 0
avg_loss=0
# A first round of the validation loss.
n_steps_eval = 0
eval_loss = 0
val_iterator = validation_ds.as_numpy_iterator()
for val_example in val_iterator:
eval_loss += compiled_validation_step(model,
params,
dataset_builder._tokenizer.pad_id,
val_example)
n_steps_eval += 1
print(f"Start, validation loss: {eval_loss/n_steps_eval}")
for train_example in train_ds:
train_loss, params, opt_state = compiled_train_step(model=model,
params=params,
optimizer=optimizer,
opt_state=opt_state,
pad_id=dataset_builder._tokenizer.pad_id,
example=train_example)
n_steps += 1
avg_loss += train_loss
if n_steps % training_cfg.eval_every_n == 0:
eval_loss = 0
n_steps_eval = 0
val_iterator = validation_ds.as_numpy_iterator()
for val_example in val_iterator:
eval_loss += compiled_validation_step(model,
params,
dataset_builder._tokenizer.pad_id,
val_example)
n_steps_eval +=1
avg_loss /= training_cfg.eval_every_n
eval_loss /= n_steps_eval
print(f"STEP {n_steps} training loss: {avg_loss} - eval loss: {eval_loss}")
avg_loss=0
if training_cfg.max_steps is not None and n_steps > training_cfg.max_steps:
break
return params
開始以有限的步驟 (SEQ_SIZE
) 微調 Gemma 模型,確保這符合記憶體大小:
SEQ_SIZE = 25
tokenizer = GemmaTokenizer(vocab)
dataset_builder= MTNTDatasetBuilder(tokenizer, SEQ_SIZE)
training_cfg = TrainingConfig(learning_rate=1e-4,
num_epochs=1,
eval_every_n=20,
batch_size=1,
max_steps=100)
params = train_loop(model=model_2b,
params={'params': params['transformer']},
dataset_builder=dataset_builder,
training_cfg=training_cfg)
Start, validation loss: 10.647212982177734 STEP 20 training loss: 3.3015992641448975 - eval loss: 2.686880111694336 STEP 40 training loss: 5.375057220458984 - eval loss: 2.6751961708068848 STEP 60 training loss: 2.6599338054656982 - eval loss: 2.663877010345459 STEP 80 training loss: 4.822389125823975 - eval loss: 2.3333375453948975 STEP 100 training loss: 2.0131142139434814 - eval loss: 2.360811948776245
每個步數都會導致訓練損失和驗證損失都會下降。
使用 gemma.sampler.Sampler
建立 sampler
。使用 Gemma 模型查核點和符記化工具。
sampler = sampler_lib.Sampler(
transformer=model_2b,
vocab=vocab,
params=params['params'],
)
使用 sampler
檢查模型是否能執行翻譯。gemma.sampler.Sampler
中的 total_generation_steps
引數是產生回應時執行的步驟數。為確保輸入內容與訓練格式相符,請使用前置字串 Translate this into French:\n
,並在結尾加上換行字元。這會指示模型開始翻譯。
sampler(
["Translate this into French:\nHello, my name is Morgane.\n"],
total_generation_steps=100,
).text
["C'est Bonjour, mon nom est Morgane.C'est Bonjour, mon nom est Morgane."]
瞭解詳情
- 您可以進一步瞭解 GitHub 上的 Google DeepMind 程式庫
gemma
,其中包含您在本教學課程中使用的模組 docstring,例如gemma.params
、gemma.transformer
和gemma.sampler
。 - 下列程式庫都有專屬的說明文件網站:核心 JAX、Flax、Chex、Optax 以及 Orbax。
- 如需
sentencepiece
權杖化工具/解碼器說明文件,請前往 Google 的sentencepiece
GitHub 存放區。 - 如需
kagglehub
說明文件,請查看 Kagglekagglehub
GitHub 存放區中的README.md
。 - 瞭解如何搭配使用 Gemma 模型與 Google Cloud Vertex AI。
- 如果您使用的是 Google Cloud TPU (v3-8 及以上版本),請務必一併更新至最新的
jax[tpu]
套件 (!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
)、重新啟動執行階段,並檢查jax
和jaxlib
版本是否相符 (!pip list | grep jax
)。這麼做可以避免因jaxlib
和jax
版本不符而可能發生的RuntimeError
。如需更多 JAX 安裝操作說明,請參閱 JAX 文件。