Lihat di ai.google.dev | Berjalan di Google Colab | Buka di Vertex AI | Lihat sumber di GitHub |
Tutorial ini menunjukkan cara melakukan pengambilan sampel/inferensi dasar dengan model instruksi 2B RecurrentGemma menggunakan library recurrentgemma
Google DeepMind yang ditulis dengan JAX (library komputasi numerik berperforma tinggi), Flax (library jaringan neural berbasis JAX), Orbax (library berbasis JAX untuk utilitas pelatihan seperti checkpointing/token{i> <i}seperti{i> <i}{i>checkpointizer<i}{/12), dan {Sentizer{i> <i}{i>Sentizer} JAX <i}token{i> tokenizer <i}").SentencePiece Meskipun Flax tidak digunakan secara langsung di notebook ini, Flax digunakan untuk membuat Gemma dan RecurrentGemma (model Griffin).
Notebook ini dapat berjalan di Google Colab dengan GPU T4 (buka Edit > Setelan notebook > Di bagian Akselerator hardware, pilih T4 GPU).
Penyiapan
Bagian berikut menjelaskan langkah-langkah untuk menyiapkan notebook agar dapat menggunakan model RecurrentGemma, termasuk akses model, mendapatkan kunci API, dan mengonfigurasi runtime notebook
Menyiapkan akses Kaggle untuk Gemma
Untuk menyelesaikan tutorial ini, pertama-tama Anda harus mengikuti petunjuk penyiapan yang mirip dengan penyiapan Gemma dengan beberapa pengecualian:
- Dapatkan akses ke RecurrentGemma (bukan Gemma) di kaggle.com.
- Pilih runtime Colab dengan resource yang memadai untuk menjalankan model RecurrentGemma.
- Membuat dan mengkonfigurasi nama pengguna dan kunci API Kaggle.
Setelah Anda menyelesaikan penyiapan RecurrentGemma, lanjutkan ke bagian berikutnya, tempat Anda akan menetapkan variabel lingkungan untuk lingkungan Colab Anda.
Menetapkan variabel lingkungan
Menetapkan variabel lingkungan untuk KAGGLE_USERNAME
dan KAGGLE_KEY
. Saat melihat dialog "Berikan akses?", pesan, setuju untuk memberikan
akses rahasia.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
Menginstal library recurrentgemma
Notebook ini berfokus pada penggunaan GPU Colab gratis. Untuk mengaktifkan akselerasi hardware, klik Edit > Setelan notebook > Pilih T4 GPU > Simpan.
Selanjutnya, Anda perlu menginstal library recurrentgemma
Google DeepMind dari github.com/google-deepmind/recurrentgemma
. Jika mendapatkan error tentang "resolver dependensi pip", Anda biasanya bisa mengabaikannya.
pip install git+https://github.com/google-deepmind/recurrentgemma.git
Memuat dan menyiapkan model RecurrentGemma
- Muat model RecurrentGemma dengan
kagglehub.model_download
, yang menggunakan tiga argumen:
handle
: Handle model dari Kagglepath
: (String opsional) Jalur lokalforce_download
: (Boolean opsional) Memaksa untuk mendownload ulang model
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub
RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download... 100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s] Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
- Periksa lokasi bobot model dan tokenizer, lalu tetapkan variabel jalur. Direktori tokenizer akan berada di direktori utama tempat Anda mendownload model, sedangkan bobot model akan berada di sub-direktori. Contoh:
- File
tokenizer.model
akan berada di/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1
). - Checkpoint model akan berada di
/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it
).
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model
Melakukan sampling/inferensi
- Muat checkpoint model RecurrentGemma dengan metode
recurrentgemma.jax.load_parameters
. Argumensharding
yang ditetapkan ke"single_device"
memuat semua parameter model di satu perangkat.
import recurrentgemma
from recurrentgemma import jax as recurrentgemma
params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
- Muat tokenizer model RecurrentGemma, yang dibuat menggunakan
sentencepiece.SentencePieceProcessor
:
import sentencepiece as spm
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
- Untuk otomatis memuat konfigurasi yang benar dari checkpoint model RecurrentGemma, gunakan
recurrentgemma.GriffinConfig.from_flax_params_or_variables
. Kemudian, buat instance model Griffin denganrecurrentgemma.jax.Griffin
.
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
flax_params_or_variables=params)
model = recurrentgemma.Griffin(model_config)
- Buat
sampler
denganrecurrentgemma.jax.Sampler
di atas checkpoint/bobot model RecurrentGemma dan tokenizer:
sampler = recurrentgemma.Sampler(
model=model,
vocab=vocab,
params=params,
)
- Tulis perintah di
prompt
dan lakukan inferensi. Anda dapat menyesuaikantotal_generation_steps
(jumlah langkah yang dilakukan saat membuat respons — contoh ini menggunakan50
untuk menghemat memori host).
prompt = [
"\n# 5+9=?",
]
reply = sampler(input_strings=prompt,
total_generation_steps=50,
)
for input_string, out_string in zip(prompt, reply.text):
print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation. warnings.warn("Some donated buffers were not usable:" Prompt: # 5+9=? Output: # Answer: 14 # Explanation: 5 + 9 = 14.
Pelajari lebih lanjut
- Anda dapat mempelajari lebih lanjut library
recurrentgemma
Google DeepMind di GitHub, yang berisi dokumen metode dan modul yang Anda gunakan dalam tutorial ini, sepertirecurrentgemma.jax.load_parameters
,recurrentgemma.jax.Griffin
, danrecurrentgemma.jax.Sampler
. - Library berikut memiliki situs dokumentasinya sendiri: core JAX, Flax, dan Orbax.
- Untuk dokumentasi tokenizer/detokenizer
sentencepiece
, lihat repo GitHubsentencepiece
Google. - Untuk dokumentasi
kagglehub
, lihatREADME.md
di repo GitHubkagglehub
Kaggle. - Pelajari cara menggunakan model Gemma dengan Vertex AI Google Cloud.
- Lihat RecurrentGemma: Moving Past Transformers untuk Makalah Model Bahasa Terbuka yang Efisien oleh Google DeepMind.
- Baca artikel Griffin: Mencampur Pengulangan Linear Terbatas dengan Makalah Local Attention for Efficient Language Models oleh GoogleDeepMind untuk mempelajari lebih lanjut arsitektur model yang digunakan oleh RecurrentGemma.