Menjalankan inferensi dengan Gemma menggunakan Keras

Lihat di ai.google.dev Berjalan di Google Colab Buka di Vertex AI Lihat sumber di GitHub

Tutorial ini menunjukkan cara menggunakan Gemma dengan KerasNLP untuk menjalankan inferensi dan menghasilkan teks. Gemma adalah sekumpulan model terbuka yang ringan dan canggih, dibangun dari riset dan teknologi yang sama dengan yang digunakan untuk membuat model Gemini. KerasNLP adalah kumpulan model natural language processing (NLP) yang diimplementasikan di Keras serta dapat dijalankan di JAX, PyTorch, dan TensorFlow.

Dalam tutorial ini, Anda akan menggunakan Gemma untuk menghasilkan respons teks untuk beberapa perintah. Jika baru menggunakan Keras, Anda mungkin ingin membaca Getting started with Keras sebelum memulai, tetapi Anda tidak perlu melakukannya. Anda akan mempelajari Keras lebih lanjut selagi mengerjakan tutorial ini.

Penyiapan

Penyiapan Gemma

Untuk menyelesaikan tutorial ini, Anda harus menyelesaikan petunjuk penyiapan terlebih dahulu di penyiapan Gemma. Petunjuk penyiapan Gemma menunjukkan kepada Anda cara melakukan hal berikut:

  • Dapatkan akses ke Gemma di kaggle.com.
  • Pilih runtime Colab dengan resource yang memadai untuk menjalankan model Gemma 2B.
  • Membuat dan mengkonfigurasi nama pengguna dan kunci API Kaggle.

Setelah Anda menyelesaikan penyiapan Gemma, lanjutkan ke bagian berikutnya, untuk menetapkan variabel lingkungan untuk lingkungan Colab Anda.

Menetapkan variabel lingkungan

Menetapkan variabel lingkungan untuk KAGGLE_USERNAME dan KAGGLE_KEY.

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Menginstal dependensi

Instal Keras dan KerasNLP.

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U keras>=3

Pilih backend

Keras adalah API deep learning multi-framework tingkat tinggi yang dirancang untuk kesederhanaan dan kemudahan penggunaan. Keras 3 memungkinkan Anda memilih backend: TensorFlow, JAX, atau PyTorch. Ketiganya akan berfungsi untuk tutorial ini.

import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

Mengimpor paket

Impor Keras dan KerasNLP.

import keras
import keras_nlp

Membuat model

KerasNLP menyediakan implementasi dari banyak arsitektur model populer. Dalam tutorial ini, Anda akan membuat model menggunakan GemmaCausalLM, model Gemma menyeluruh untuk pemodelan bahasa kausal. Model bahasa kausal memprediksi token berikutnya berdasarkan token sebelumnya.

Buat model menggunakan metode from_preset:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...

Fungsi GemmaCausalLM.from_preset() membuat instance model dari bobot dan arsitektur preset. Dalam kode di atas, string "gemma_2b_en" menentukan preset model Gemma 2B dengan 2 miliar parameter. Model Gemma dengan parameter 7B, 9B, dan 27B juga tersedia. Anda dapat menemukan string kode untuk model Gemma dalam listingan Variasi Model di kaggle.com.

Gunakan summary untuk mendapatkan info selengkapnya tentang model ini:

gemma_lm.summary()

Seperti yang dapat Anda lihat dari ringkasan, model ini memiliki 2,5 miliar parameter yang dapat dilatih.

Membuat teks

Sekarang saatnya membuat beberapa teks! Model ini memiliki metode generate yang menghasilkan teks berdasarkan prompt. Argumen max_length opsional menentukan panjang maksimum urutan yang dihasilkan.

Cobalah dengan perintah "What is the meaning of life?".

gemma_lm.generate("What is the meaning of life?", max_length=64)
'What is the meaning of life?\n\nThe question is one of the most important questions in the world.\n\nIt’s the question that has been asked by philosophers, theologians, and scientists for centuries.\n\nAnd it’s the question that has been asked by people who are looking for answers to their own lives'

Coba panggil generate lagi dengan perintah lain.

gemma_lm.generate("How does the brain work?", max_length=64)
'How does the brain work?\n\nThe brain is the most complex organ in the human body. It is responsible for controlling all of the body’s functions, including breathing, heart rate, digestion, and more. The brain is also responsible for thinking, feeling, and making decisions.\n\nThe brain is made up'

Jika menjalankan backend JAX atau TensorFlow, Anda akan melihat bahwa panggilan generate kedua ditampilkan hampir seketika. Hal ini karena setiap panggilan ke generate untuk ukuran batch tertentu dan max_length dikompilasi dengan XLA. Proses pertama berbiaya mahal, tetapi operasi berikutnya jauh lebih cepat.

Anda juga dapat menyediakan perintah batch menggunakan daftar sebagai input:

gemma_lm.generate(
    ["What is the meaning of life?",
     "How does the brain work?"],
    max_length=64)
['What is the meaning of life?\n\nThe question is one of the most important questions in the world.\n\nIt’s the question that has been asked by philosophers, theologians, and scientists for centuries.\n\nAnd it’s the question that has been asked by people who are looking for answers to their own lives',
 'How does the brain work?\n\nThe brain is the most complex organ in the human body. It is responsible for controlling all of the body’s functions, including breathing, heart rate, digestion, and more. The brain is also responsible for thinking, feeling, and making decisions.\n\nThe brain is made up']

Opsional: Coba sampel lain

Anda dapat mengontrol strategi pembuatan untuk GemmaCausalLM dengan menetapkan argumen sampler di compile(). Secara default, pengambilan sampel "greedy" akan digunakan.

Sebagai eksperimen, coba tetapkan strategi "top_k":

gemma_lm.compile(sampler="top_k")
gemma_lm.generate("What is the meaning of life?", max_length=64)
'What is the meaning of life? That was a question I asked myself as I was driving home from work one night in 2012. I was driving through the city of San Bernardino, and all I could think was, “What the heck am I doing?”\n\nMy life was completely different. I'

Meskipun algoritma greedy default selalu mengambil token dengan probabilitas terbesar, algoritma top-K secara acak mengambil token berikutnya dari token probabilitas top K.

Anda tidak perlu menentukan sampler, dan Anda dapat mengabaikan cuplikan kode terakhir jika tidak membantu kasus penggunaan Anda. Jika Anda ingin mempelajari lebih lanjut sampler yang tersedia, lihat Sampler.

Langkah selanjutnya

Dalam tutorial ini, Anda telah mempelajari cara membuat teks menggunakan KerasNLP dan Gemma. Berikut adalah beberapa saran tentang hal-hal yang perlu dipelajari selanjutnya: