![]() | ![]() | ![]() | | ![]() |
Vështrim i përgjithshëm
Gemma është një familje modelesh të hapura me peshë të lehtë dhe moderne të ndërtuara nga kërkimi dhe teknologjia e përdorur për krijimin e modeleve të Google Gemini. Gemma mund të rregullohet më tej për t'iu përshtatur nevojave specifike. Por modelet e mëdha të gjuhës, të tilla si Gemma, mund të jenë shumë të mëdha në përmasa dhe disa prej tyre mund të mos përshtaten në një përshpejtues këndimi për rregullim të imët. Në këtë rast ekzistojnë dy qasje të përgjithshme për rregullimin e tyre të imët:
- Parametri Efficient Fine-Tuning (PEFT), i cili kërkon të zvogëlojë madhësinë efektive të modelit duke sakrifikuar njëfarë besnikërie. LoRA bie në këtë kategori dhe modelet e rregullimit të Gemma në Keras duke përdorur tutorialin LoRA demonstrojnë se si të rregulloni modelin Gemma 2B
gemma_2b_en
me LoRA duke përdorur KerasNLP në një GPU të vetme. - Rregullimi i plotë i parametrave me paralelizëm modeli. Paralelizmi i modelit shpërndan peshat e një modeli të vetëm nëpër pajisje të shumta dhe mundëson shkallëzimin horizontal. Mund të mësoni më shumë rreth trajnimit të shpërndarë në këtë udhëzues Keras .
Ky udhëzues ju udhëzon në përdorimin e Keras me një prapavijë JAX për të rregulluar modelin Gemma 7B me LoRA dhe trajnime të shpërndara me model-parallizëm në Njësinë e Përpunimit Tensor të Google (TPU). Vini re se LoRA mund të çaktivizohet në këtë tutorial për një akordim më të ngadaltë por më të saktë me parametra të plotë.
Përdorimi i përshpejtuesve
Teknikisht mund të përdorni ose TPU ose GPU për këtë tutorial.
Shënime mbi mjediset TPU
Google ka 3 produkte që ofrojnë TPU:
- Colab ofron TPU v2 falas, gjë që është e mjaftueshme për këtë tutorial.
- Kaggle ofron TPU v3 falas dhe ata gjithashtu punojnë për këtë tutorial.
- Cloud TPU ofron TPU v3 dhe gjenerata më të reja. Një mënyrë për ta konfiguruar është:
- Krijo një VM të re TPU
- Konfiguro përcjelljen e portit SSH për portin e synuar të serverit Jupyter
- Instaloni Jupyter dhe niseni në TPU VM, më pas lidheni me Colab përmes "Lidhu me një kohë ekzekutimi lokal"
Shënime mbi konfigurimin me shumë GPU
Megjithëse ky tutorial fokusohet në rastin e përdorimit të TPU-së, mund ta përshtatni lehtësisht për nevojat tuaja nëse keni një makinë me shumë GPU.
Nëse preferoni të punoni përmes Colab, është gjithashtu e mundur të siguroni një VM me shumë GPU për Colab drejtpërdrejt përmes "Lidhu me një VM të personalizuar GCE" në menynë Colab Connect.
Ne do të përqendrohemi në përdorimin e TPU-së falas nga Kaggle këtu.
Para se të filloni
Kredencialet e Kaggle
Modelet Gemma priten nga Kaggle. Për të përdorur Gemma, kërkoni qasje në Kaggle:
- Identifikohu ose regjistrohu në kaggle.com
- Hapni kartën e modelit Gemma dhe zgjidhni "Kërko qasje"
- Plotësoni formularin e pëlqimit dhe pranoni termat dhe kushtet
Më pas, për të përdorur Kaggle API, krijoni një shenjë API:
- Hapni cilësimet e Kaggle
- Zgjidhni "Krijo Token të Ri"
- Një skedar
kaggle.json
është shkarkuar. Ai përmban kredencialet tuaja Kaggle
Drejtoni qelizën e mëposhtme dhe shkruani kredencialet tuaja Kaggle kur ju kërkohet.
# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub
kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…
Një mënyrë alternative është të vendosni KAGGLE_USERNAME dhe KAGGLE_KEY në mjedisin tuaj nëse kagglehub.login() nuk funksionon për ju.
Instalimi
Instaloni Keras dhe KerasNLP me modelin Gemma.
pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu tensorflow-text
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras
Konfiguro backend Keras JAX
Importoni JAX dhe kryeni një kontroll të arsyeshëm në TPU. Kaggle ofron pajisje TPUv3-8 të cilat kanë 8 bërthama TPU me 16 GB memorie secila.
import jax
jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os
# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"
Modeli i ngarkesës
import keras
import keras_nlp
Shënime mbi trajnimin me saktësi të përzier në GPU-të NVIDIA
Kur stërviteni në GPU-të NVIDIA, saktësia e përzier ( keras.mixed_precision.set_global_policy('mixed_bfloat16')
) mund të përdoret për të shpejtuar stërvitjen me efekt minimal në cilësinë e stërvitjes. Në shumicën e rasteve, rekomandohet të aktivizoni saktësinë e përzier pasi kursen kujtesën dhe kohën. Megjithatë, kini parasysh se në madhësi të vogla të grupeve, ai mund të rrisë përdorimin e memories me 1.5x (peshat do të ngarkohen dy herë, me gjysmë saktësie dhe saktësi të plotë).
Për konkluzion, gjysma e saktësisë ( keras.config.set_floatx("bfloat16")
) do të funksionojë dhe do të kursejë memorie ndërsa precizioni i përzier nuk është i zbatueshëm.
# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')
Për të ngarkuar modelin me peshat dhe tensorët e shpërndarë nëpër TPU, fillimisht krijoni një DeviceMesh
të ri. DeviceMesh
përfaqëson një koleksion pajisjesh harduerike të konfiguruara për llogaritje të shpërndara dhe u prezantua në Keras 3 si pjesë e API-së së unifikuar të shpërndarjes.
API-ja e shpërndarjes mundëson paralelizmin e të dhënave dhe modeleve, duke lejuar shkallëzimin efikas të modeleve të të mësuarit të thellë në përshpejtues dhe pritës të shumtë. Ai përdor kuadrin themelor (p.sh. JAX) për të shpërndarë programin dhe tensorët sipas direktivave të ndarjes përmes një procedure të quajtur zgjerimi i një programi të vetëm, të dhëna të shumëfishta (SPMD). Shikoni më shumë detaje në udhëzuesin e ri API të shpërndarjes Keras 3 .
# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
(1, 8),
["batch", "model"],
devices=keras.distribution.list_devices())
LayoutMap
nga API-ja e shpërndarjes specifikon se si peshat dhe tensorët duhet të copëtohen ose përsëriten, duke përdorur çelësat e vargut, për shembull, token_embedding/embeddings
më poshtë, të cilët trajtohen si regex për të përputhur shtigjet e tensorit. Tenzorët e përputhur janë të copëtuar me dimensionet e modelit (8 TPU); të tjerat do të përsëriten plotësisht.
model_dim = "model"
layout_map = keras.distribution.LayoutMap(device_mesh)
# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
model_dim, None, None)
layout_map["decoder_block.*attention_output.*kernel"] = (
model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)
ModelParallel
ju lejon të copëtoni peshat e modelit ose tensorët e aktivizimit në të gjitha pajisjet në DeviceMesh
. Në këtë rast, disa nga peshat e modelit Gemma 7B ndahen në 8 çipa TPU sipas layout_map
të përcaktuar më sipër. Tani ngarkoni modelin në mënyrën e shpërndarë.
model_parallel = keras.distribution.ModelParallel(
layout_map=layout_map, batch_dim_name="batch")
keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.
Tani verifikoni që modeli është ndarë saktë. Le të marrim si shembull decoder_block_1
.
decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
print(f'{variable.path:<58} {str(variable.shape):<16} {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'> decoder_block_1/pre_attention_norm/scale (3072,) PartitionSpec(None,) decoder_block_1/attention/query/kernel (16, 3072, 256) PartitionSpec(None, 'model', None) decoder_block_1/attention/key/kernel (16, 3072, 256) PartitionSpec(None, 'model', None) decoder_block_1/attention/value/kernel (16, 3072, 256) PartitionSpec(None, 'model', None) decoder_block_1/attention/attention_output/kernel (16, 256, 3072) PartitionSpec(None, None, 'model') decoder_block_1/pre_ffw_norm/scale (3072,) PartitionSpec(None,) decoder_block_1/ffw_gating/kernel (3072, 24576) PartitionSpec('model', None) decoder_block_1/ffw_gating_2/kernel (3072, 24576) PartitionSpec('model', None) decoder_block_1/ffw_linear/kernel (24576, 3072) PartitionSpec(None, 'model')
Konkluzioni përpara rregullimit të imët
gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'
Modelja gjeneron një listë të filmave komedi të shkëlqyer nga vitet '90 për t'u parë. Tani ne rregullojmë modelin Gemma për të ndryshuar stilin e daljes.
Përmirësohu me IMDB
import tensorflow_datasets as tfds
imdb_train = tfds.load(
"imdb_reviews",
split="train",
as_supervised=True,
batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)
imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0... Dl Completed...: 0 url [00:00, ? url/s] Dl Size...: 0 MiB [00:00, ? MiB/s] Generating splits...: 0%| | 0/3 [00:00<?, ? splits/s] Generating train examples...: 0%| | 0/25000 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord… Generating test examples...: 0%| | 0/25000 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*… Generating unsupervised examples...: 0%| | 0/50000 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t… Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data. b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)
Kryeni rregullimin e imët duke përdorur Përshtatjen me gradë të ulët (LoRA). LoRA është një teknikë e rregullimit të imët e cila redukton në masë të madhe numrin e parametrave të trajnueshëm për detyrat në rrjedhën e poshtme duke ngrirë peshat e plota të modelit dhe duke futur një numër më të vogël peshash të reja të trajnueshme në model. Në thelb LoRA riparametizon matricat më të mëdha me peshë të plotë me 2 matrica më të vogla të nivelit të ulët AxB për t'u stërvitur dhe kjo teknikë e bën stërvitjen shumë më të shpejtë dhe më efikase për memorie.
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.
# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
learning_rate=5e-5,
weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])
gemma_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=optimizer,
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation. warnings.warn("Some donated buffers were not usable:" 2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329 <keras.src.callbacks.history.History at 0x7e9cac7f41c0>
Vini re se aktivizimi i LoRA redukton ndjeshëm numrin e parametrave të trajnueshëm, nga 7 miliardë në vetëm 11 milionë.
Konkluzioni pas rregullimit të imët
gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."
Pas rregullimit të imët, modelja ka mësuar stilin e rishikimeve të filmave dhe tani po gjeneron rezultate në atë stil në kontekstin e filmave komedi të viteve '90.
Çfarë është më pas
Në këtë tutorial, ju mësuat se si të përdorni KerasNLP JAX backend për të rregulluar një model Gemma në grupin e të dhënave IMDb në një mënyrë të shpërndarë në TPU-të e fuqishme. Këtu janë disa sugjerime se çfarë tjetër për të mësuar:
- Mësoni se si të filloni me Keras Gemma .
- Mësoni se si të rregulloni modelin Gemma në GPU .