Inferensi dengan Gemma menggunakan JAX dan Flax

Lihat di ai.google.dev Menjalankan di Google Colab Terbuka di Vertex AI Lihat sumber di GitHub

Ringkasan

Gemma adalah kumpulan model bahasa besar terbuka yang ringan dan canggih, berdasarkan riset dan teknologi Google DeepMind Gemini. Tutorial ini menunjukkan cara melakukan pengambilan sampel/inferensi dasar dengan model Gemma 2B Instructions menggunakan library gemma Google DeepMind yang ditulis dengan JAX (library komputasi numerik berperforma tinggi), Flax (library jaringan neural berbasis JAX), Orbax (library berbasis JAX untuk utilitas pelatihan seperti checkpointing), dan SentencePiece Meskipun Flax tidak digunakan langsung di notebook ini, Flax digunakan untuk membuat Gemma.

Notebook ini dapat dijalankan di Google Colab dengan GPU T4 gratis (buka Edit > Notebook settings > Di bagian Hardware Accelerate, pilih T4 GPU).

Penyiapan

1. Menyiapkan akses Kaggle untuk Gemma

Untuk menyelesaikan tutorial ini, Anda harus mengikuti petunjuk penyiapan terlebih dahulu di Penyiapan Gemma, yang menunjukkan cara melakukan hal berikut:

  • Dapatkan akses ke Gemma di kaggle.com.
  • Pilih runtime Colab dengan resource yang memadai untuk menjalankan model Gemma.
  • Membuat dan mengkonfigurasi nama pengguna dan kunci API Kaggle.

Setelah menyelesaikan penyiapan Gemma, lanjutkan ke bagian berikutnya yang berisi cara menetapkan variabel lingkungan untuk lingkungan Colab Anda.

2. Menetapkan variabel lingkungan

Tetapkan variabel lingkungan untuk KAGGLE_USERNAME dan KAGGLE_KEY. Saat diminta dengan pesan "Berikan akses?", setujui untuk memberikan akses rahasia.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. Menginstal library gemma

Notebook ini berfokus pada penggunaan GPU Colab gratis. Untuk mengaktifkan akselerasi hardware, klik Edit > Setelan notebook > Pilih T4 GPU > Save.

Selanjutnya, Anda perlu menginstal library gemma Google DeepMind dari github.com/google-deepmind/gemma. Jika Anda mendapatkan pesan error tentang "resolver dependensi pip", Anda biasanya dapat mengabaikannya.

pip install -q git+https://github.com/google-deepmind/gemma.git

Memuat dan menyiapkan model Gemma

  1. Muat model Gemma dengan kagglehub.model_download, yang menggunakan tiga argumen:
  • handle: Tuas model dari Kaggle
  • path: (String opsional) Jalur lokal
  • force_download: (Boolean opsional) Memaksa mendownload ulang model
GEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/gemma/flax/2b-it/2/download...
100%|██████████| 3.67G/3.67G [00:35<00:00, 110MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2
  1. Periksa lokasi bobot model dan tokenizer, lalu tetapkan variabel jalur. Direktori tokenizer akan berada di direktori utama tempat Anda mendownload model, sedangkan bobot model akan berada di sub-direktori. Contoh:
  • File tokenizer.model akan berada di /LOCAL/PATH/TO/gemma/flax/2b-it/2).
  • Checkpoint model akan berada di /LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it).
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/tokenizer.model

Melakukan pengambilan sampel/inferensi

  1. Muat dan format checkpoint model Gemma dengan metode gemma.params.load_and_format_params:
from gemma import params as params_lib

params = params_lib.load_and_format_params(CKPT_PATH)
  1. Muat tokenizer Gemma yang dibuat menggunakan sentencepiece.SentencePieceProcessor:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. Untuk otomatis memuat konfigurasi yang benar dari checkpoint model Gemma, gunakan gemma.transformer.TransformerConfig. Argumen cache_size adalah jumlah langkah waktu dalam cache Transformer Gemma. Setelah itu, buat instance model Gemma sebagai transformer dengan gemma.transformer.Transformer (yang mewarisi dari flax.linen.Module).
from gemma import transformer as transformer_lib

transformer_config = transformer_lib.TransformerConfig.from_params(
    params=params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(transformer_config)
  1. Buat sampler dengan gemma.sampler.Sampler selain checkpoint/bobot model Gemma dan tokenizer:
from gemma import sampler as sampler_lib

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer'],
)
  1. Tulis prompt di input_batch dan lakukan inferensi. Anda dapat menyesuaikan total_generation_steps (jumlah langkah yang dilakukan saat membuat respons — contoh ini menggunakan 100 untuk menghemat memori host).
prompt = [
    "\n# What is the meaning of life?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=100,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
Prompt:

# What is the meaning of life?
Output:


The question of what the meaning of life is one that has occupied the minds of philosophers, theologians, and individuals for centuries. There is no single, universally accepted answer, but there are many different perspectives on this complex and multifaceted question.

**Some common perspectives on the meaning of life include:**

* **Biological perspective:** From a biological standpoint, the meaning of life is to survive and reproduce.
* **Existential perspective:** Existentialists believe that life is not inherently meaningful and that
  1. (Opsional) Jalankan sel ini untuk mengosongkan memori jika Anda telah menyelesaikan notebook dan ingin mencoba perintah lain. Setelah itu, Anda dapat membuat instance sampler lagi di langkah 3, serta menyesuaikan dan menjalankan perintah di langkah 4.
del sampler

Pelajari lebih lanjut