Keras를 사용하여 Gemma를 사용한 분산 조정

ai.google.dev에서 보기 Kaggle에서 실행 Vertex AI에서 열기 GitHub에서 소스 보기

개요

Gemma는 Google Gemini 모델을 만드는 데 사용된 연구와 기술을 바탕으로 개발된 경량의 최첨단 개방형 모델 제품군입니다. Gemma를 특정 요구에 맞게 더욱 세밀하게 조정할 수 있습니다. 하지만 Gemma와 같은 대규모 언어 모델은 크기가 매우 클 수 있으며, 일부는 미세 조정을 위한 싱크 액셀러레이터에 적합하지 않을 수 있습니다. 이 경우 다음과 같은 두 가지 일반적인 방법으로 미세 조정할 수 있습니다.

  1. PEFT (매개변수 효율적 미세 조정): 일부 충실도를 희생하여 유효 모델 크기를 축소합니다. LoRA는 이 카테고리에 속합니다. LoRA를 사용하여 Keras에서 Gemma 모델 미세 조정 가이드에서는 단일 GPU에서 KerasNLP를 사용하여 LoRA로 Gemma 2B 모델 gemma_2b_en을 미세 조정하는 방법을 보여줍니다.
  2. 모델 동시 로드로 전체 매개변수 미세 조정 모델 동시 로드는 단일 모델의 가중치를 여러 기기에 분산하고 수평 확장을 지원합니다. Keras 가이드에서 분산 학습에 대해 자세히 알아볼 수 있습니다.

이 튜토리얼에서는 Keras를 JAX 백엔드와 함께 사용하여 LoRA를 통해 Gemma 7B 모델을 미세 조정하고 Google의 Tensor Processing Unit (TPU)에서 Model-parallism 분산 학습을 수행하는 방법을 안내합니다. 느리지만 더 정확한 전체 매개변수를 조정하려면 이 튜토리얼에서 LoRA를 사용 중지할 수 있습니다.

가속기 사용

기술적으로 이 가이드에서는 TPU나 GPU를 사용할 수 있습니다.

TPU 환경에 관한 참고사항

Google은 TPU를 제공하는 3가지 제품을 보유하고 있습니다.

  • Colab은 TPU v2를 제공하지만 이 가이드에서는 충분하지 않습니다.
  • Kaggle은 TPU v3를 무료로 제공하며 이 튜토리얼에 적합합니다.
  • Cloud TPU는 TPU v3 이상 세대를 제공합니다. 설정 방법은 다음과 같습니다.
    1. TPU VM 만들기
    2. 원하는 Jupyter 서버 포트에 SSH 포트 전달 설정
    3. Jupyter를 설치하고 TPU VM에서 시작한 다음 '로컬 런타임에 연결'을 통해 Colab에 연결

다중 GPU 설정 참고 사항

이 튜토리얼에서는 TPU 사용 사례를 중점적으로 다루지만 다중 GPU 머신이 있다면 필요에 따라 손쉽게 조정할 수 있습니다.

Colab을 통해 작업하고 싶다면 Colab Connect 메뉴의 '맞춤 GCE VM에 연결'을 통해 Colab용 멀티 GPU VM을 직접 프로비저닝할 수도 있습니다.

여기서는 Kaggle의 무료 TPU를 사용하는 데 중점을 두겠습니다.

시작하기 전에

Kaggle 사용자 인증 정보

Gemma 모델은 Kaggle에서 호스팅됩니다. Gemma를 사용하려면 Kaggle에서 액세스를 요청하세요.

  • kaggle.com에서 로그인 또는 등록합니다.
  • Gemma 모델 카드를 열고 '액세스 요청'을 선택합니다.
  • 동의 양식을 작성하고 이용약관에 동의합니다.

그런 다음 Kaggle API를 사용하기 위해 API 토큰을 만듭니다.

  • Kaggle 설정을 엽니다.
  • 'Create New Token'(새 토큰 만들기)을 선택합니다.
  • kaggle.json 파일이 다운로드됩니다. 여기에는 Kaggle 사용자 인증 정보가 포함됩니다.

다음 셀을 실행하고 메시지가 표시되면 Kaggle 사용자 인증 정보를 입력합니다.

# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub

kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

다른 방법은 kagglehub.login()을 사용할 수 없을 경우 해당 환경에서 KAGGLE_USERNAME 및 KAGGLE_KEY를 설정하는 것입니다.

설치

Gemma 모델을 사용한 Keras 및 KerasNLP를 설치합니다.

pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras

Keras JAX 백엔드 설정

JAX를 가져오고 TPU에서 상태 검사를 실행합니다. Kaggle은 각각 16GB의 메모리를 갖춘 8개의 TPU 코어가 있는 TPUv3-8 기기를 제공합니다.

import jax

jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

모델 로드

import keras
import keras_nlp

NVIDIA GPU의 혼합 정밀도 학습 참고 사항

NVIDIA GPU에서 학습할 때는 혼합 정밀도 (keras.mixed_precision.set_global_policy('mixed_bfloat16'))를 사용하여 학습 품질에 미치는 영향을 최소화하면서 학습 속도를 높일 수 있습니다. 대부분의 경우 메모리와 시간을 모두 절약하는 혼합 정밀도를 사용 설정하는 것이 좋습니다. 그러나 배치 크기가 작으면 메모리 사용량이 1.5배 늘어날 수 있습니다 (절반 정밀도 및 전체 정밀도로 가중치가 두 번 로드됨).

추론의 경우 절반 정밀도 (keras.config.set_floatx("bfloat16"))가 작동하고 메모리를 절약하지만 혼합 정밀도는 적용할 수 없습니다.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

TPU에 가중치와 텐서가 분산된 모델을 로드하려면 먼저 새 DeviceMesh를 만듭니다. DeviceMesh는 분산 계산을 위해 구성된 하드웨어 기기 모음을 나타내며 통합 배포 API의 일부로 Keras 3에 도입되었습니다.

Distribution API를 사용하면 데이터 및 모델 동시 로드를 지원하여 여러 가속기 및 호스트에서 딥 러닝 모델을 효율적으로 확장할 수 있습니다. 기본 프레임워크 (예: JAX)를 활용하여 SPMD (Single Program, Multiple Data) 확장이라는 절차를 통해 샤딩 지시어에 따라 프로그램과 텐서를 배포합니다. 새로운 Keras 3 배포 API 가이드에서 자세한 내용을 확인하세요.

# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

배포 API의 LayoutMap는 문자열 키(예: 정규식처럼 텐서 경로와 일치시키기 위해 아래 token_embedding/embeddings)를 사용하여 가중치와 텐서를 샤딩하거나 복제하는 방법을 지정합니다. 일치하는 텐서는 모델 측정기준 (TPU 8개)으로 샤딩되며 나머지는 완전히 복제됩니다.

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, None, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

ModelParallel를 사용하면 DeviceMesh의 모든 기기에서 모델 가중치 또는 활성화 텐서를 샤딩할 수 있습니다. 이 경우 일부 Gemma 7B 모델 가중치는 위에 정의된 layout_map에 따라 8개의 TPU 칩에 샤딩됩니다. 이제 모델을 분산된 방식으로 로드합니다.

model_parallel = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

이제 모델이 올바르게 파티션이 나뉘었는지 확인합니다. decoder_block_1를 예로 들어보겠습니다.

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')

미세 조정 전 추론

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'

모델이 시청할 만한 90년대 코미디 영화 목록을 생성합니다. 이제 Gemma 모델을 미세 조정하여 출력 스타일을 변경합니다.

IMDB로 미세 조정

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord…
Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*…
Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t…
Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

낮은 순위 조정 (LoRA)을 사용하여 미세 조정을 수행합니다. LoRA는 모델의 전체 가중치를 고정하고 모델에 학습 가능한 새 가중치를 더 적게 삽입하여 다운스트림 태스크의 학습 가능한 매개변수 수를 크게 줄이는 미세 조정 기법입니다. 기본적으로 LoRA는 더 큰 전체 가중치 행렬을 학습에 사용할 작은 저순위 행렬 2개(AxB)로 다시 매개변수화합니다. 이 기법을 사용하면 학습의 속도와 메모리 효율성이 높아집니다.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329
<keras.src.callbacks.history.History at 0x7e9cac7f41c0>

LoRA를 사용 설정하면 학습 가능한 매개변수의 수가 70억 개에서 1, 100만 개로 크게 줄어듭니다.

미세 조정 후 추론

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."

미세 조정 후 모델은 영화 리뷰 스타일을 학습했으며 이제 90년대 코미디 영화의 맥락에서 해당 스타일로 출력을 생성하고 있습니다.

다음 단계

이 튜토리얼에서는 KerasNLP JAX 백엔드를 사용하여 강력한 TPU에서 분산된 방식으로 IMDb 데이터 세트에서 Gemma 모델을 미세 조정하는 방법을 배웠습니다. 그 밖에 알아야 할 몇 가지 제안사항은 다음과 같습니다.