ai.google.dev で表示 | Google Colab で実行 | Vertex AI で開く | GitHub のソースを表示 |
概要
Gemma は、Google DeepMind の Gemini の研究とテクノロジーに基づく、軽量で最先端のオープン大規模言語モデルのファミリーです。このチュートリアルでは、Google DeepMind の gemma
ライブラリ、JAX(高性能数値計算ライブラリ)、Flax(JAX ベースのニューラル ネットワーク ライブラリ)、Chex(信頼性の高い JAX コードを記述するためのユーティリティのライブラリ)、Optax2(テキスト処理の信頼できる JAX1)ライブラリのライブラリ)、NT(機械処理最適化のユーティリティのライブラリ)Optax(ML データセット最適化のライブラリ)(ML データセット)を使用して、英語 - フランス語翻訳タスクの Gemma 2B Instruct モデルをファインチューニングする方法について説明します。このノートブックでは Flax は直接使用されていませんが、Gemma の作成には Flax が使用されました。
gemma
ライブラリは、JAX、Flax、Orbax(チェックポインティングなどのユーティリティをトレーニングするための JAX ベースのライブラリ)、SentencePiece(トークナイザー/デトークナイザー ライブラリ)を使用して作成されています。
セットアップ
1. Gemma 用に Kaggle のアクセスを設定する
このチュートリアルを完了するには、まず、Gemma の設定に記載されている設定手順を実施する必要があります。この手順では、以下を行う方法について説明します。
- kaggle.com で Gemma にアクセスしてください。
- Gemma モデルを実行するのに十分なリソースがある Colab ランタイムを選択します。
- Kaggle のユーザー名と API キーを生成して構成します。
Gemma の設定が完了したら、次のセクションに進み、Colab 環境の環境変数を設定します。
2. 環境変数を設定する
KAGGLE_USERNAME
と KAGGLE_KEY
の環境変数を設定します。[アクセスを許可しますか?] というメッセージが表示されたら、シークレットへのアクセスを提供することに合意します。
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
3. gemma
ライブラリをインストールする
現在、このノートブックを実行するには、無料の Colab ハードウェア アクセラレーションでは不十分です。Colab 従量課金制または Colab Pro を使用している場合は、[編集] をクリックします。ノートブックの設定 >[A100 GPU] を選択 >ハードウェア アクセラレーションを有効にするには、[保存] をクリックします。
次に、github.com/google-deepmind/gemma
から Google DeepMind gemma
ライブラリをインストールする必要があります。「pip の依存関係リゾルバ」に関するエラーが発生した場合、通常は無視できます。
pip install -q git+https://github.com/google-deepmind/gemma.git
4. ライブラリをインポートする
このノートブックでは、Flax(ニューラル ネットワーク用)、コアの JAX、SentencePiece(トークン化用)、Chex(信頼性の高い JAX コードを作成するためのユーティリティのライブラリ)、TensorFlow データセットを使用します。
import os
import enum
import re
import string
import chex
import jax
import jax.numpy as jnp
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm
Gemma モデルを読み込む
kagglehub.model_download
を使用して Gemma モデルを読み込みます。これは、次の 3 つの引数を取ります。
handle
: Kaggle のモデルハンドルpath
: (省略可)ローカルパスforce_download
: (省略可のブール値)モデルの再ダウンロードを強制的に行います
GEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub
GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/gemma/flax/2b-it/2/download... 100%|██████████| 3.67G/3.67G [00:26<00:00, 147MB/s] Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2
モデルの重みとトークナイザの場所を確認して、パス変数を設定します。トークナイザー ディレクトリはモデルをダウンロードしたメイン ディレクトリにありますが、モデルの重みはサブディレクトリにあります。例:
tokenizer.model
ファイルは/LOCAL/PATH/TO/gemma/flax/2b-it/2
にあります)。- モデルのチェックポイントは
/LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it
にあります)。
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/2b-it TOKENIZER_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/tokenizer.model
MTNT データセットと Gemma トークナイザを読み込んで準備する
MTNT(Machine Translation of Noisy Text)データセットを使用します。このデータセットは TensorFlow Datasets から入手できます。
MTNT データセットの英語からフランス語のデータセット部分をダウンロードして、2 つのサンプルをサンプリングします。データセットの各サンプルには、次の 2 つのエントリが含まれています。src
: 元の英語のセンテンス。dst
: 対応するフランス語の翻訳。
ds = tfds.load("mtnt/en-fr", split="train")
ds = ds.take(2)
ds = ds.as_numpy_iterator()
for idx, example in enumerate(ds):
print(f'Example {idx}:')
for key, val in example.items():
print(f'{key}: {val}')
print()
Downloading and preparing dataset 35.08 MiB (download: 35.08 MiB, generated: 11.33 MiB, total: 46.41 MiB) to /root/tensorflow_datasets/mtnt/en-fr/1.0.0... Dl Completed...: 0 url [00:00, ? url/s] Dl Size...: 0 MiB [00:00, ? MiB/s] Extraction completed...: 0 file [00:00, ? file/s] Generating splits...: 0%| | 0/3 [00:00<?, ? splits/s] Generating train examples...: 0%| | 0/35692 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incomplete6YJMND/mtnt-train.tfrecord*...: 0%| … Generating test examples...: 0%| | 0/1020 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incomplete6YJMND/mtnt-test.tfrecord*...: 0%| |… Generating valid examples...: 0%| | 0/811 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incomplete6YJMND/mtnt-valid.tfrecord*...: 0%| … Dataset mtnt downloaded and prepared to /root/tensorflow_datasets/mtnt/en-fr/1.0.0. Subsequent calls will reuse this data. Example 0: dst: b'Le groupe de " toutes les \xc3\xa9toiles potentielles de la conf\xc3\xa9rence de l\'Est mais qui ne s\'en sortent pas dans le groupe de l\'Ouest ".' src: b'The group of \xe2\x80\x9ceastern conference potential all stars but not making it in the West\xe2\x80\x9d group.' Example 1: dst: b"Kameron est-elle un peu aigrie de son manque de temps \xc3\xa0 l'\xc3\xa9cran ?" src: b'Is Kameron a Little Salty About Her Lack of Air Time?'
sentencepiece.SentencePieceProcessor
を使用して作成された Gemma トークナイザを読み込みます。
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
英語からフランス語への翻訳タスクの SentencePieceProcessor
をカスタマイズします。Gemma モデルの英語部分をファインチューニングするため、次のような調整を行う必要があります。
入力接頭辞: 各入力に共通の接頭辞を追加すると、変換タスクが通知されます。たとえば、
Translate this into French: [INPUT_SENTENCE]
のような接頭辞を持つプロンプトを使用できます。変換開始接尾辞: 各プロンプトの末尾に接尾辞を追加すると、Gemma モデルに翻訳プロセスを開始する正確なタイミングを指示できます。新しい行で処理を行います。
言語モデル トークン: Gemma モデルは「シーケンスの開始」を必要とします。各シーケンスの最初に出現する可能性があるため、「シーケンスの終わり」にトークンで十分です。
次のように、
SentencePieceProcessor
の周囲にカスタム ラッパーを作成します。
class GemmaTokenizer:
def __init__(self,
spm_processor: spm.SentencePieceProcessor):
self._spm_processor = spm_processor
@property
def pad_id(self) -> int:
"""Fast access to the pad ID."""
return self._spm_processor.pad_id()
def tokenize(self,
example: str | bytes,
prefix: str = '',
suffix: str = '',
add_eos: bool = True) -> jax.Array:
"""
The tokenization function.
Args:
example: Input string to tokenize.
prefix: Prefix to add to the input string.
suffix: Suffix to add to the input string.
add_eos: If True, add an "end of sentence" token at the end of the output
sequence.
Returns:
Tokens corresponding to the input string.
"""
int_list = [self._spm_processor.bos_id()]
int_list.extend(self._spm_processor.EncodeAsIds(prefix + example + suffix))
if add_eos:
int_list.append(self._spm_processor.eos_id())
return jnp.array(int_list, dtype=jnp.int32)
def tokenize_tf_op(self,
str_tensor: tf.Tensor,
prefix: str = '',
suffix: str = '',
add_eos: bool = True) -> tf.Tensor:
"""A TensorFlow operator for the tokenize function."""
encoded = tf.numpy_function(
self.tokenize,
[str_tensor, prefix, suffix, add_eos],
tf.int32)
encoded.set_shape([None])
return encoded
def to_string(self, tokens: jax.Array) -> str:
"""Convert an array of tokens to a string."""
return self._spm_processor.EncodeIds(tokens.tolist())
新しいカスタム GemmaTokenizer
をインスタンス化し、MTNT データセットの小規模なサンプルに適用して試してみましょう。
tokenizer = GemmaTokenizer(vocab)
def tokenize_source(tokenizer, example: tf.Tensor):
return tokenizer.tokenize_tf_op(example,
prefix='Translate this into French:\n',
suffix='\n',
add_eos=False)
def tokenize_destination(tokenizer, example: tf.Tensor):
return tokenizer.tokenize_tf_op(example,
add_eos=True)
ds = tfds.load("mtnt/en-fr",split="train")
ds = ds.take(2)
ds = ds.map(lambda x: {'src': tokenize_source(tokenizer, x['src']),
'dst': tokenize_destination(tokenizer, x['dst'])})
ds = ds.as_numpy_iterator()
for idx, example in enumerate(ds):
print(f'Example {idx}:')
for key, val in example.items():
print(f'{key}: {val}')
print()
Example 0: src: [ 2 49688 736 1280 6987 235292 108 651 2778 576 1080 104745 11982 5736 832 8995 901 780 3547 665 575 573 4589 235369 2778 235265 108] dst: [ 2 2025 29653 581 664 16298 1437 55563 41435 7840 581 683 111452 581 533 235303 9776 4108 2459 679 485 235303 479 6728 579 1806 2499 709 29653 581 533 235303 101323 16054 1] Example 1: src: [ 2 49688 736 1280 6987 235292 108 2437 87150 477 476 11709 230461 8045 3636 40268 576 4252 4897 235336 108] dst: [ 2 213606 477 1455 235290 3510 748 8268 191017 2809 581 2032 69972 581 11495 1305 533 235303 65978 1654 1]
MTNT データセット全体のデータローダを作成します。
@chex.dataclass(frozen=True)
class TrainingInput:
# Input tokens provided to the model.
input_tokens: jax.Array
# A mask that determines which tokens contribute to the target loss
# calculation.
target_mask: jax.Array
class DatasetSplit(enum.Enum):
TRAIN = 'train'
VALIDATION = 'valid'
class MTNTDatasetBuilder:
"""The dataset builder for the MTNT dataset."""
N_ITEMS = {DatasetSplit.TRAIN: 35_692,
DatasetSplit.VALIDATION: 811}
BUFFER_SIZE_SHUFFLE = 10_000
TRANSLATION_PREFIX = 'Translate this into French:\n'
TRANSLATION_SUFFIX = '\n'
def __init__(self,
tokenizer : GemmaTokenizer,
max_seq_len: int):
"""Constructor.
Args:
tokenizer: Gemma tokenizer to use.
max_seq_len: size of each sequence in a given batch.
"""
self._tokenizer = tokenizer
self._base_data = {
DatasetSplit.TRAIN: tfds.load("mtnt/en-fr",split="train"),
DatasetSplit.VALIDATION: tfds.load("mtnt/en-fr",split="valid"),
}
self._max_seq_len = max_seq_len
def _tokenize_source(self, example: tf.Tensor):
"""Tokenization function for the source."""
return self._tokenizer.tokenize_tf_op(example,
prefix=self.TRANSLATION_PREFIX,
suffix=self.TRANSLATION_SUFFIX,
add_eos=False)
def _tokenize_destination(self, example: tf.Tensor):
"""Tokenization function for the French translation."""
return self._tokenizer.tokenize_tf_op(example,
add_eos=True)
def _pad_up_to_max_len(self,
input_tensor: tf.Tensor,
pad_value: int | bool,
) -> tf.Tensor:
"""Pad the given tensor up to sequence length of a batch."""
seq_len = tf.shape(input_tensor)[0]
to_pad = tf.maximum(self._max_seq_len - seq_len, 0)
return tf.pad(input_tensor,
[[0, to_pad]],
mode='CONSTANT',
constant_values=pad_value,
)
def _to_training_input(self,
src_tokens: jax.Array,
dst_tokens: jax.Array,
) -> TrainingInput:
"""Build a training input from a tuple of source and destination tokens."""
# The input sequence fed to the model is simply the concatenation of the
# source and the destination.
tokens = tf.concat([src_tokens, dst_tokens], axis=0)
# To prevent the model from updating based on the source (input)
# tokens, add a target mask to each input.
q_mask = tf.zeros_like(src_tokens, dtype=tf.bool)
a_mask = tf.ones_like(dst_tokens, dtype=tf.bool)
mask = tf.concat([q_mask, a_mask], axis=0)
# If the output tokens sequence is smaller than the target sequence size,
# then pad it with pad tokens.
tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)
# Don't want to perform the backward pass on the pad tokens.
mask = self._pad_up_to_max_len(mask, False)
return TrainingInput(input_tokens=tokens, target_mask=mask)
def get_train_dataset(self, batch_size: int, num_epochs: int):
"""Build the training dataset."""
# Tokenize each sample.
ds = self._base_data[DatasetSplit.TRAIN].map(lambda x : (self._tokenize_source(x['src']),
self._tokenize_destination(x['dst'])))
# Convert the samples to training inputs.
ds = ds.map(lambda x, y: self._to_training_input(x, y))
# Remove the samples that are too long.
ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
# Shuffle the dataset.
ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)
# Repeat if necessary.
ds = ds.repeat(num_epochs)
# Build batches.
ds = ds.batch(batch_size, drop_remainder=True)
return ds
def get_validation_dataset(self, batch_size: int):
"""Build the validation dataset."""
# Same steps as in `get_train_dataset`, but without shuffling and no repetition.
ds = self._base_data[DatasetSplit.VALIDATION].map(lambda x : (self._tokenize_source(x['src']),
self._tokenize_destination(x['dst'])))
ds = ds.map(lambda x, y: self._to_training_input(x, y))
ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
ds = ds.batch(batch_size, drop_remainder=True)
return ds
カスタムの GemmaTokenizer
を再度インスタンス化し、MTNT データセットに適用して 2 つの例をサンプリングして、MTNTDatasetBuilder
を試します。
tokenizer = GemmaTokenizer(vocab)
dataset_builder = MTNTDatasetBuilder(tokenizer, max_seq_len=20)
ds = dataset_builder.get_train_dataset(3, 1)
ds = ds.take(2)
ds = ds.as_numpy_iterator()
for idx, example in enumerate(ds):
print(f'Example {idx}:')
for key, val in example.items():
print(f'{key}: {val}')
print()
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'> WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'> WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'> Example 0: input_tokens: [[ 2 49688 736 1280 6987 235292 108 10924 665 12302 235341 108 2 4397 63011 1437 38696 1241 1 0] [ 2 49688 736 1280 6987 235292 108 13835 1517 235265 108 2 69875 540 19713 235265 1 0 0 0] [ 2 49688 736 1280 6987 235292 108 6956 1586 235297 235265 108 2 78368 1586 235297 235265 1 0 0]] target_mask: [[False False False False False False False False False False False False True True True True True True True False] [False False False False False False False False False False False True True True True True True False False False] [False False False False False False False False False False False False True True True True True True False False]] Example 1: input_tokens: [[ 2 49688 736 1280 6987 235292 108 18874 235341 108 2 115905 6425 1241 1 0 0 0 0 0] [ 2 49688 736 1280 6987 235292 108 7574 3356 235341 108 2 7997 20707 1241 1 0 0 0 0] [ 2 49688 736 1280 6987 235292 108 8703 665 235265 108 2 235338 235303 90006 20133 235265 1 0 0]] target_mask: [[False False False False False False False False False False True True True True True False False False False False] [False False False False False False False False False False False True True True True True False False False False] [False False False False False False False False False False False True True True True True True True False False]]
モデルを構成する
Gemma モデルのファインチューニングを開始する前に、Gemma モデルを構成する必要があります。
まず、gemma.params.load_and_format_params
メソッドで Gemma モデルのチェックポイントを読み込んでフォーマットします。
params = params_lib.load_and_format_params(CKPT_PATH)
Gemma モデルのチェックポイントから正しい構成を自動的に読み込むには、gemma.transformer.TransformerConfig
を使用します。cache_size
引数は、Gemma の Transformer
キャッシュ内のタイムステップの数です。その後、gemma.transformer.Transformer
(flax.linen.Module
から継承)を使用して、Gemma モデルを model_2b
としてインスタンス化します。
config_2b = transformer_lib.TransformerConfig.from_params(
params,
cache_size=30
)
model_2b = transformer_lib.Transformer(config=config_2b)
モデルをファインチューニングする
このセクションでは、次の作業を行います。
gemma.transformer.Transformer
クラスを使用して、フォワードパスと損失関数を作成します。- トークンの位置マスク ベクトルとアテンション マスク ベクトルを作成する
- Flax を使用してトレーニングのステップ関数を作成する。
- バックワード パスを使用せずに検証ステップをビルドします。
- トレーニング ループを作成する。
- Gemma モデルをファインチューニングします。
gemma.transformer.Transformer
クラスを使用してフォワードパスと損失関数を定義します。Gemma の Transformer
は flax.linen.Module
から継承されており、次の 2 つの基本的なメソッドを提供します。
init
: モデルのパラメータを初期化します。apply
: 指定されたパラメータ セットを使用してモデルの__call__
関数を実行します。事前トレーニング済みの Gemma の重みを使用しているため、
init
関数を使用する必要はありません。
def forward_and_loss_fn(params,
*,
model: transformer_lib.Transformer,
input_tokens: jax.Array, # Shape [B, L]
input_mask: jax.Array, # Shape [B, L]
positions: jax.Array, # Shape [B, L]
attention_mask: jax.Array, # [B, L, L]
) -> jax.Array:
"""The forward pass and the loss function.
Args:
params: Model's input parameters.
model: The Gemma transformer model to call.
input_tokens: Input tokens sequence, shape [B, L].
input_mask: Tokens to ignore when computing the loss, shape [B, L].
positions: Relative position of each token, shape [B, L].
attention_mask: Input attention mask, shape [B, L].
Returns:
The softmax cross-entropy loss for the next-token prediction task.
"""
# The forward pass on the input data.
# No attention cache is needed here.
logits, _ = model.apply(
params,
input_tokens,
positions,
None, # Attention cache is None.
attention_mask,
)
# Exclude the last step as it does not appear in the targets.
logits = logits[0, :-1]
# Similarly, the first token cannot be predicted.
target_tokens = input_tokens[0, 1:]
target_mask = input_mask[0, 1:]
# Convert the target labels to one-hot encoded vectors.
one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])
# Don't update on unwanted tokens.
one_hot = one_hot * target_mask.astype(one_hot.dtype)[...,None]
# Define the normalization factor.
norm_factor = 1 / (jnp.sum(target_mask) + 1e-8)
# Return the negative log likelihood (NLL) loss.
return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) * norm_factor
gemma.transformer.Transformer
クラスには、各入力とともに attention_mask
と positions
ベクトルが必要です。これらを生成するには、Transformer.build_positions_from_mask
と Transformer.make_causal_attn_mask
を使用するカスタム関数を作成します。
def get_attention_mask_and_positions(example: jax.Array,
pad_id : int,
)-> tuple[jax.Array, jax.Array]:
"""Builds the position and attention mask vectors from the given tokens."""
pad_mask = example != pad_id
current_token_position = transformer_lib.build_positions_from_mask(pad_mask)
attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)
return current_token_position, attention_mask
バックワード パスを実行し、それに応じてモデルのパラメータを更新する train_step
関数を作成します。ここで、
jax.value_and_grad
は、フォワードパスとバックワード パスの間の損失関数と勾配を評価するために使用します。optax.apply_updates
はパラメータを更新するためのものです。
def train_step(model: transformer_lib.Transformer,
params,
optimizer: optax.GradientTransformation,
opt_state: optax.OptState,
pad_id: int,
example: TrainingInput):
"""Train step.
Args:
model: The Gemma transformer model.
params: The model's input parameters.
optimizer: The Optax optimizer to use.
opt_state: The input optimizer's state.
pad_id: ID of the pad token.
example: Input batch.
Returns:
The training loss, the updated parameters, and the updated optimizer state.
"""
# Build the position and attention mask vectors.
positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)
# The forward and backward passes.
train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(params,
model=model,
input_tokens=example.input_tokens,
input_mask=example.target_mask,
positions=positions,
attention_mask=attention_mask)
# Update the parameters.
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
return train_loss, params, opt_state
バックワード パスを使用せずに validation_step
関数を作成します。
def validation_step(model: transformer_lib.Transformer,
params,
pad_id: int,
example: TrainingInput,
):
positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)
val_loss = forward_and_loss_fn(params,
model=model,
input_tokens=example.input_tokens,
input_mask=example.target_mask,
positions=positions,
attention_mask=attention_mask)
return val_loss
SGD オプティマイザーに optax.sgd
を使用して、トレーニング ループを定義します。
@chex.dataclass(frozen=True)
class TrainingConfig:
learning_rate: float
num_epochs: int
eval_every_n: int
batch_size: int
max_steps: int | None = None
def train_loop(
model: transformer_lib.Transformer,
params,
dataset_builder: MTNTDatasetBuilder,
training_cfg: TrainingConfig):
# Apply `jax.jit` on the training step, making the whole loop much more efficient.
compiled_train_step = jax.jit(train_step, static_argnames=['model', 'optimizer'])
# Apply `jax.jit` on the validation step.
compiled_validation_step = jax.jit(validation_step, static_argnames=['model'])
# To save memory, use the SGD optimizer instead of the usual Adam optimizer.
# Note that for this specific example, SGD is more than enough.
optimizer = optax.sgd(training_cfg.learning_rate)
opt_state = optimizer.init(params)
# Build the training dataset.
train_ds = dataset_builder.get_train_dataset(batch_size=training_cfg.batch_size,
num_epochs=training_cfg.num_epochs)
train_ds = train_ds.as_numpy_iterator()
# Build the validation dataset, with a limited number of samples for this demo.
validation_ds = dataset_builder.get_validation_dataset(batch_size=training_cfg.batch_size)
validation_ds = validation_ds.take(50)
n_steps = 0
avg_loss=0
# A first round of the validation loss.
n_steps_eval = 0
eval_loss = 0
val_iterator = validation_ds.as_numpy_iterator()
for val_example in val_iterator:
eval_loss += compiled_validation_step(model,
params,
dataset_builder._tokenizer.pad_id,
val_example)
n_steps_eval += 1
print(f"Start, validation loss: {eval_loss/n_steps_eval}")
for train_example in train_ds:
train_loss, params, opt_state = compiled_train_step(model=model,
params=params,
optimizer=optimizer,
opt_state=opt_state,
pad_id=dataset_builder._tokenizer.pad_id,
example=train_example)
n_steps += 1
avg_loss += train_loss
if n_steps % training_cfg.eval_every_n == 0:
eval_loss = 0
n_steps_eval = 0
val_iterator = validation_ds.as_numpy_iterator()
for val_example in val_iterator:
eval_loss += compiled_validation_step(model,
params,
dataset_builder._tokenizer.pad_id,
val_example)
n_steps_eval +=1
avg_loss /= training_cfg.eval_every_n
eval_loss /= n_steps_eval
print(f"STEP {n_steps} training loss: {avg_loss} - eval loss: {eval_loss}")
avg_loss=0
if training_cfg.max_steps is not None and n_steps > training_cfg.max_steps:
break
return params
限られたステップ(SEQ_SIZE
)で Gemma モデルのファインチューニングを開始し、メモリに収まることを確認します。
SEQ_SIZE = 25
tokenizer = GemmaTokenizer(vocab)
dataset_builder= MTNTDatasetBuilder(tokenizer, SEQ_SIZE)
training_cfg = TrainingConfig(learning_rate=1e-4,
num_epochs=1,
eval_every_n=20,
batch_size=1,
max_steps=100)
params = train_loop(model=model_2b,
params={'params': params['transformer']},
dataset_builder=dataset_builder,
training_cfg=training_cfg)
Start, validation loss: 10.647212982177734 STEP 20 training loss: 3.3015992641448975 - eval loss: 2.686880111694336 STEP 40 training loss: 5.375057220458984 - eval loss: 2.6751961708068848 STEP 60 training loss: 2.6599338054656982 - eval loss: 2.663877010345459 STEP 80 training loss: 4.822389125823975 - eval loss: 2.3333375453948975 STEP 100 training loss: 2.0131142139434814 - eval loss: 2.360811948776245
トレーニングの損失と検証の損失は、ステップ数が増えるごとに減少しているはずです。
gemma.sampler.Sampler
で sampler
を作成します。Gemma モデルのチェックポイントとトークナイザを使用します。
sampler = sampler_lib.Sampler(
transformer=model_2b,
vocab=vocab,
params=params['params'],
)
sampler
を使用して、モデルが翻訳を行えるかどうかを確認します。gemma.sampler.Sampler
の total_generation_steps
引数は、レスポンスの生成時に実行されるステップ数です。入力をトレーニング形式と一致させるには、接頭辞 Translate this into French:\n
の末尾に改行文字を使用します。これにより、モデルに翻訳を開始するように指示できます。
sampler(
["Translate this into French:\nHello, my name is Morgane.\n"],
total_generation_steps=100,
).text
["C'est Bonjour, mon nom est Morgane.C'est Bonjour, mon nom est Morgane."]
その他の情報
- Google DeepMind の GitHub の
gemma
ライブラリの詳細を確認できます。これには、このチュートリアルで使用したモジュールの docstring が含まれています(gemma.params
、gemma.transformer
、gemma.sampler
。 - core JAX、Flax、Chex、Optax、Orbax という独自のドキュメント サイトがあります。
sentencepiece
トークナイザーとデトークナイザーのドキュメントについては、Google のsentencepiece
GitHub リポジトリをご覧ください。kagglehub
のドキュメントについては、Kaggle のkagglehub
GitHub リポジトリでREADME.md
をご覧ください。- Google Cloud Vertex AI で Gemma モデルを使用する方法を学習する。
- Google Cloud TPU(v3-8 以降)を使用している場合は、最新の
jax[tpu]
パッケージ(!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
)に更新し、ランタイムを再起動して、jax
とjaxlib
のバージョンが一致している(!pip list | grep jax
)ことを確認してください。これにより、jaxlib
とjax
のバージョンの不一致が原因で発生するRuntimeError
を防ぐことができます。JAX のインストール手順の詳細については、JAX のドキュメントをご覧ください。