Inferensi dengan RecurrentGemma menggunakan JAX dan Flax

Lihat di ai.google.dev Berjalan di Google Colab Buka di Vertex AI Lihat sumber di GitHub

Tutorial ini menunjukkan cara melakukan pengambilan sampel/inferensi dasar dengan model instruksi 2B RecurrentGemma menggunakan library recurrentgemma Google DeepMind yang ditulis dengan JAX (library komputasi numerik berperforma tinggi), Flax (library jaringan neural berbasis JAX), Orbax (library berbasis JAX untuk utilitas pelatihan seperti checkpointing/token{i> <i}seperti{i> <i}{i>checkpointizer<i}{/12), dan {Sentizer{i> <i}{i>Sentizer} JAX <i}token{i> tokenizer <i}").SentencePiece Meskipun Flax tidak digunakan secara langsung di notebook ini, Flax digunakan untuk membuat Gemma dan RecurrentGemma (model Griffin).

Notebook ini dapat berjalan di Google Colab dengan GPU T4 (buka Edit > Setelan notebook > Di bagian Akselerator hardware, pilih T4 GPU).

Penyiapan

Bagian berikut menjelaskan langkah-langkah untuk menyiapkan notebook agar dapat menggunakan model RecurrentGemma, termasuk akses model, mendapatkan kunci API, dan mengonfigurasi runtime notebook

Menyiapkan akses Kaggle untuk Gemma

Untuk menyelesaikan tutorial ini, pertama-tama Anda harus mengikuti petunjuk penyiapan yang mirip dengan penyiapan Gemma dengan beberapa pengecualian:

  • Dapatkan akses ke RecurrentGemma (bukan Gemma) di kaggle.com.
  • Pilih runtime Colab dengan resource yang memadai untuk menjalankan model RecurrentGemma.
  • Membuat dan mengkonfigurasi nama pengguna dan kunci API Kaggle.

Setelah Anda menyelesaikan penyiapan RecurrentGemma, lanjutkan ke bagian berikutnya, tempat Anda akan menetapkan variabel lingkungan untuk lingkungan Colab Anda.

Menetapkan variabel lingkungan

Menetapkan variabel lingkungan untuk KAGGLE_USERNAME dan KAGGLE_KEY. Saat melihat dialog "Berikan akses?", pesan, setuju untuk memberikan akses rahasia.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Menginstal library recurrentgemma

Notebook ini berfokus pada penggunaan GPU Colab gratis. Untuk mengaktifkan akselerasi hardware, klik Edit > Setelan notebook > Pilih T4 GPU > Simpan.

Selanjutnya, Anda perlu menginstal library recurrentgemma Google DeepMind dari github.com/google-deepmind/recurrentgemma. Jika mendapatkan error tentang "resolver dependensi pip", Anda biasanya bisa mengabaikannya.

pip install git+https://github.com/google-deepmind/recurrentgemma.git

Memuat dan menyiapkan model RecurrentGemma

  1. Muat model RecurrentGemma dengan kagglehub.model_download, yang menggunakan tiga argumen:
  • handle: Handle model dari Kaggle
  • path: (String opsional) Jalur lokal
  • force_download: (Boolean opsional) Memaksa untuk mendownload ulang model
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download...
100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s]
Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
  1. Periksa lokasi bobot model dan tokenizer, lalu tetapkan variabel jalur. Direktori tokenizer akan berada di direktori utama tempat Anda mendownload model, sedangkan bobot model akan berada di sub-direktori. Contoh:
  • File tokenizer.model akan berada di /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1).
  • Checkpoint model akan berada di /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it).
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model

Melakukan sampling/inferensi

  1. Muat checkpoint model RecurrentGemma dengan metode recurrentgemma.jax.load_parameters. Argumen sharding yang ditetapkan ke "single_device" memuat semua parameter model di satu perangkat.
import recurrentgemma
from recurrentgemma import jax as recurrentgemma

params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
  1. Muat tokenizer model RecurrentGemma, yang dibuat menggunakan sentencepiece.SentencePieceProcessor:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. Untuk otomatis memuat konfigurasi yang benar dari checkpoint model RecurrentGemma, gunakan recurrentgemma.GriffinConfig.from_flax_params_or_variables. Kemudian, buat instance model Griffin dengan recurrentgemma.jax.Griffin.
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
    flax_params_or_variables=params)

model = recurrentgemma.Griffin(model_config)
  1. Buat sampler dengan recurrentgemma.jax.Sampler di atas checkpoint/bobot model RecurrentGemma dan tokenizer:
sampler = recurrentgemma.Sampler(
    model=model,
    vocab=vocab,
    params=params,
)
  1. Tulis perintah di prompt dan lakukan inferensi. Anda dapat menyesuaikan total_generation_steps (jumlah langkah yang dilakukan saat membuat respons — contoh ini menggunakan 50 untuk menghemat memori host).
prompt = [
    "\n# 5+9=?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=50,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
Prompt:

# 5+9=?
Output:


# Answer: 14

# Explanation: 5 + 9 = 14.

Pelajari lebih lanjut