Suy luận bằng RecurrentGemma bằng JAX và Flax

Xem trên ai.google.dev Chạy trong Google Colab Mở trong Vertex AI Xem nguồn trên GitHub

Hướng dẫn này minh hoạ cách thực hiện lấy mẫu/suy luận cơ bản với mô hình Hướng dẫn RecurrentGemma 2B bằng cách sử dụng thư viện recurrentgemma của Google DeepMind được viết bằng JAX (một thư viện điện toán số hiệu suất cao), Flax (thư viện mạng nơron dựa trên JAX), Orbax (một thư viện dựa trên JAX (một thư viện dựa trên JAX) để huấn luyện {12token} mã thông báo như checkpointing).SentencePiece Mặc dù Flax không được sử dụng trực tiếp trong sổ tay này, nhưng Flax đã được sử dụng để tạo Gemma và RecurrentGemma (mô hình Griffin).

Sổ tay này có thể chạy trên Google Colab với GPU T4 (chuyển đến phần Chỉnh sửa > Cài đặt sổ tay > Trong phần Trình tăng tốc phần cứng, hãy chọn GPU T4).

Thiết lập

Các phần sau đây giải thích các bước chuẩn bị để sử dụng mô hình RecurrentGemma cho sổ tay, bao gồm cả quyền truy cập vào mô hình, lấy khoá API và định cấu hình thời gian chạy của sổ tay

Thiết lập quyền truy cập vào Kaggle cho Gemma

Để hoàn tất hướng dẫn này, trước tiên, bạn cần làm theo các hướng dẫn thiết lập tương tự như thiết lập Gemma với một số ngoại lệ:

  • Truy cập vào RecurrentGemma (thay vì Gemma) trên kaggle.com.
  • Chọn một môi trường thời gian chạy Colab có đủ tài nguyên để chạy mô hình RecurrentGemma.
  • Tạo và định cấu hình tên người dùng Kaggle và khoá API.

Sau khi hoàn tất quy trình thiết lập RecurrentGemma, hãy chuyển sang phần tiếp theo để thiết lập các biến môi trường cho môi trường Colab của bạn.

Đặt các biến môi trường

Thiết lập các biến môi trường cho KAGGLE_USERNAMEKAGGLE_KEY. Khi được nhắc "Cấp quyền truy cập?" tin nhắn, đồng ý cấp quyền truy cập bí mật.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Cài đặt thư viện recurrentgemma

Sổ tay này tập trung vào việc sử dụng GPU Colab miễn phí. Để bật chế độ tăng tốc phần cứng, hãy nhấp vào Edit (Chỉnh sửa) > Cài đặt sổ tay > Chọn GPU T4 > Lưu.

Tiếp theo, bạn cần cài đặt thư viện Google DeepMind recurrentgemma từ github.com/google-deepmind/recurrentgemma. Nếu gặp lỗi "trình phân giải phần phụ thuộc của pip", bạn thường có thể bỏ qua lỗi đó.

pip install git+https://github.com/google-deepmind/recurrentgemma.git

Tải và chuẩn bị mô hình RecurrentGemma

  1. Tải mô hình RecurrentGemma bằng kagglehub.model_download. Thao tác này sẽ nhận 3 đối số:
  • handle: Tên người dùng mô hình trong Kaggle
  • path: (Chuỗi không bắt buộc) Đường dẫn cục bộ
  • force_download: (Boolean không bắt buộc) Buộc tải lại mô hình xuống
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download...
100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s]
Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
  1. Kiểm tra vị trí của trọng số mô hình và trình tạo mã thông báo, sau đó đặt các biến đường dẫn. Thư mục tokenizer sẽ nằm trong thư mục chính mà bạn đã tải mô hình xuống, còn trọng số của mô hình sẽ nằm trong thư mục con. Ví dụ:
  • Tệp tokenizer.model sẽ nằm trong /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1).
  • Điểm kiểm tra mô hình sẽ nằm trong /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it).
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model

Thực hiện lấy mẫu/suy luận

  1. Tải điểm kiểm tra mô hình RecurrentGemma bằng phương thức recurrentgemma.jax.load_parameters. Đối số sharding được đặt thành "single_device" sẽ tải tất cả tham số của mô hình trên một thiết bị.
import recurrentgemma
from recurrentgemma import jax as recurrentgemma

params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
  1. Tải bộ mã hoá mã thông báo mô hình RecurrentGemma, được tạo bằng sentencepiece.SentencePieceProcessor:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. Để tự động tải cấu hình chính xác từ điểm kiểm tra mô hình RecurrentGemma, hãy sử dụng recurrentgemma.GriffinConfig.from_flax_params_or_variables. Sau đó, tạo thực thể cho mô hình Griffin bằng recurrentgemma.jax.Griffin.
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
    flax_params_or_variables=params)

model = recurrentgemma.Griffin(model_config)
  1. Tạo samplerrecurrentgemma.jax.Sampler ở đầu điểm kiểm tra/trọng số mô hình RecurrentGemma và trình tạo mã thông báo:
sampler = recurrentgemma.Sampler(
    model=model,
    vocab=vocab,
    params=params,
)
  1. Viết một câu lệnh trong prompt và tiến hành suy luận. Bạn có thể tinh chỉnh total_generation_steps (số bước được thực hiện khi tạo phản hồi – ví dụ này sử dụng 50 để bảo toàn bộ nhớ máy chủ).
prompt = [
    "\n# 5+9=?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=50,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
Prompt:

# 5+9=?
Output:


# Answer: 14

# Explanation: 5 + 9 = 14.

Tìm hiểu thêm