Xem trên ai.google.dev | Chạy trong Google Colab | Mở trong Vertex AI | Xem nguồn trên GitHub |
Hướng dẫn này minh hoạ cách thực hiện lấy mẫu/suy luận cơ bản với mô hình Hướng dẫn RecurrentGemma 2B bằng cách sử dụng thư viện recurrentgemma
của Google DeepMind được viết bằng JAX (một thư viện điện toán số hiệu suất cao), Flax (thư viện mạng nơron dựa trên JAX), Orbax (một thư viện dựa trên JAX (một thư viện dựa trên JAX) để huấn luyện {12token} mã thông báo như checkpointing).SentencePiece Mặc dù Flax không được sử dụng trực tiếp trong sổ tay này, nhưng Flax đã được sử dụng để tạo Gemma và RecurrentGemma (mô hình Griffin).
Sổ tay này có thể chạy trên Google Colab với GPU T4 (chuyển đến phần Chỉnh sửa > Cài đặt sổ tay > Trong phần Trình tăng tốc phần cứng, hãy chọn GPU T4).
Thiết lập
Các phần sau đây giải thích các bước chuẩn bị để sử dụng mô hình RecurrentGemma cho sổ tay, bao gồm cả quyền truy cập vào mô hình, lấy khoá API và định cấu hình thời gian chạy của sổ tay
Thiết lập quyền truy cập vào Kaggle cho Gemma
Để hoàn tất hướng dẫn này, trước tiên, bạn cần làm theo các hướng dẫn thiết lập tương tự như thiết lập Gemma với một số ngoại lệ:
- Truy cập vào RecurrentGemma (thay vì Gemma) trên kaggle.com.
- Chọn một môi trường thời gian chạy Colab có đủ tài nguyên để chạy mô hình RecurrentGemma.
- Tạo và định cấu hình tên người dùng Kaggle và khoá API.
Sau khi hoàn tất quy trình thiết lập RecurrentGemma, hãy chuyển sang phần tiếp theo để thiết lập các biến môi trường cho môi trường Colab của bạn.
Đặt các biến môi trường
Thiết lập các biến môi trường cho KAGGLE_USERNAME
và KAGGLE_KEY
. Khi được nhắc "Cấp quyền truy cập?" tin nhắn, đồng ý cấp quyền truy cập bí mật.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
Cài đặt thư viện recurrentgemma
Sổ tay này tập trung vào việc sử dụng GPU Colab miễn phí. Để bật chế độ tăng tốc phần cứng, hãy nhấp vào Edit (Chỉnh sửa) > Cài đặt sổ tay > Chọn GPU T4 > Lưu.
Tiếp theo, bạn cần cài đặt thư viện Google DeepMind recurrentgemma
từ github.com/google-deepmind/recurrentgemma
. Nếu gặp lỗi "trình phân giải phần phụ thuộc của pip", bạn thường có thể bỏ qua lỗi đó.
pip install git+https://github.com/google-deepmind/recurrentgemma.git
Tải và chuẩn bị mô hình RecurrentGemma
- Tải mô hình RecurrentGemma bằng
kagglehub.model_download
. Thao tác này sẽ nhận 3 đối số:
handle
: Tên người dùng mô hình trong Kagglepath
: (Chuỗi không bắt buộc) Đường dẫn cục bộforce_download
: (Boolean không bắt buộc) Buộc tải lại mô hình xuống
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub
RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download... 100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s] Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
- Kiểm tra vị trí của trọng số mô hình và trình tạo mã thông báo, sau đó đặt các biến đường dẫn. Thư mục tokenizer sẽ nằm trong thư mục chính mà bạn đã tải mô hình xuống, còn trọng số của mô hình sẽ nằm trong thư mục con. Ví dụ:
- Tệp
tokenizer.model
sẽ nằm trong/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1
). - Điểm kiểm tra mô hình sẽ nằm trong
/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it
).
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model
Thực hiện lấy mẫu/suy luận
- Tải điểm kiểm tra mô hình RecurrentGemma bằng phương thức
recurrentgemma.jax.load_parameters
. Đối sốsharding
được đặt thành"single_device"
sẽ tải tất cả tham số của mô hình trên một thiết bị.
import recurrentgemma
from recurrentgemma import jax as recurrentgemma
params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
- Tải bộ mã hoá mã thông báo mô hình RecurrentGemma, được tạo bằng
sentencepiece.SentencePieceProcessor
:
import sentencepiece as spm
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
- Để tự động tải cấu hình chính xác từ điểm kiểm tra mô hình RecurrentGemma, hãy sử dụng
recurrentgemma.GriffinConfig.from_flax_params_or_variables
. Sau đó, tạo thực thể cho mô hình Griffin bằngrecurrentgemma.jax.Griffin
.
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
flax_params_or_variables=params)
model = recurrentgemma.Griffin(model_config)
- Tạo
sampler
córecurrentgemma.jax.Sampler
ở đầu điểm kiểm tra/trọng số mô hình RecurrentGemma và trình tạo mã thông báo:
sampler = recurrentgemma.Sampler(
model=model,
vocab=vocab,
params=params,
)
- Viết một câu lệnh trong
prompt
và tiến hành suy luận. Bạn có thể tinh chỉnhtotal_generation_steps
(số bước được thực hiện khi tạo phản hồi – ví dụ này sử dụng50
để bảo toàn bộ nhớ máy chủ).
prompt = [
"\n# 5+9=?",
]
reply = sampler(input_strings=prompt,
total_generation_steps=50,
)
for input_string, out_string in zip(prompt, reply.text):
print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation. warnings.warn("Some donated buffers were not usable:" Prompt: # 5+9=? Output: # Answer: 14 # Explanation: 5 + 9 = 14.
Tìm hiểu thêm
- Bạn có thể tìm hiểu thêm về thư viện Google DeepMind
recurrentgemma
trên GitHub. Thư viện này chứa các chuỗi tài liệu của các phương thức và mô-đun mà bạn đã sử dụng trong hướng dẫn này, chẳng hạn nhưrecurrentgemma.jax.load_parameters
,recurrentgemma.jax.Griffin
vàrecurrentgemma.jax.Sampler
. - Các thư viện sau đây có các trang web tài liệu riêng: core JAX, Flax và Orbax.
- Để xem tài liệu về trình tạo mã thông báo/trình huỷ mã thông báo
sentencepiece
, hãy tham khảo kho lưu trữ GitHubsentencepiece
của Google. - Để xem tài liệu về
kagglehub
, hãy tham khảoREADME.md
trên kho lưu trữ GitHubkagglehub
của Kaggle. - Tìm hiểu cách sử dụng mô hình Gemma với Vertex AI của Google Cloud.
- Hãy xem RecurrentGemma: Di chuyển bộ chuyển đổi trước đây bài viết về Mô hình ngôn ngữ mở hiệu quả của Google DeepMind.
- Đọc Griffin: Kết hợp lặp lại tuyến tính có cổng vào với Bài viết của GoogleDeepMind về chú ý cục bộ dành cho mô hình ngôn ngữ hiệu quả để tìm hiểu thêm về cấu trúc mô hình mà RecurrentGemma sử dụng.