Dự đoán cùng Gemma bằng JAX và Flax

Xem trên ai.google.dev Chạy trong Google Colab Mở trong Vertex AI Xem nguồn trên GitHub

Tổng quan

Gemma là một dòng mô hình ngôn ngữ lớn, hiện đại và gọn nhẹ, dựa trên nghiên cứu và công nghệ của Google DeepMind Gemini. Phần hướng dẫn này minh hoạ cách thực hiện lấy mẫu/dự đoán cơ bản bằng mô hình Hướng dẫn Gemma 2B bằng thư viện gemma của Google DeepMind được viết bằng JAX (thư viện tính toán số hiệu suất cao), Flax (thư viện mạng nơron dựa trên JAX), Orbax (thư viện dựa trên JAX dành cho các tiện ích huấn luyện như checkpointing) và SentencePiece Mặc dù Flax không được sử dụng trực tiếp trong sổ tay này, nhưng Flax đã được sử dụng để tạo Gemma.

Sổ tay này có thể chạy trên Google Colab với GPU T4 miễn phí (chuyển đến phần Chỉnh sửa > Cài đặt sổ tay > Trong phần Trình tăng tốc phần cứng, hãy chọn T4 GPU).

Thiết lập

1. Thiết lập quyền truy cập Kaggle cho Gemma

Để hoàn tất hướng dẫn này, trước tiên bạn cần làm theo hướng dẫn thiết lập trong bài viết Thiết lập Gemma. Phần này hướng dẫn bạn cách thực hiện những thao tác sau:

  • Truy cập vào Gemma trên kaggle.com.
  • Chọn một môi trường thời gian chạy Colab có đủ tài nguyên để chạy mô hình Gemma.
  • Tạo và định cấu hình tên người dùng Kaggle và khoá API.

Sau khi hoàn tất quy trình thiết lập Gemma, hãy chuyển sang phần tiếp theo. Tại đây, bạn sẽ đặt các biến môi trường cho môi trường Colab của mình.

2. Đặt các biến môi trường

Thiết lập các biến môi trường cho KAGGLE_USERNAMEKAGGLE_KEY. Khi được nhắc với thông báo "Cấp quyền truy cập?", bạn hãy đồng ý cung cấp quyền truy cập bí mật.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. Cài đặt thư viện gemma

Sổ tay này tập trung vào việc sử dụng GPU Colab miễn phí. Để bật tính năng tăng tốc phần cứng, hãy nhấp vào Chỉnh sửa > Cài đặt sổ tay > Chọn GPU T4 > Lưu.

Tiếp theo, bạn cần cài đặt thư viện Google DeepMind gemma từ github.com/google-deepmind/gemma. Nếu gặp lỗi về "trình phân giải phần phụ thuộc của pip", bạn thường có thể bỏ qua lỗi này.

pip install -q git+https://github.com/google-deepmind/gemma.git

Tải và chuẩn bị mô hình Gemma

  1. Tải mô hình Gemma bằng kagglehub.model_download, sẽ lấy 3 đối số:
  • handle: Tên người dùng mô hình trong Kaggle
  • path: (Chuỗi không bắt buộc) Đường dẫn cục bộ
  • force_download: (Boolean không bắt buộc) Buộc tải lại mô hình xuống
GEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/gemma/flax/2b-it/2/download...
100%|██████████| 3.67G/3.67G [00:35<00:00, 110MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2
  1. Kiểm tra vị trí của trọng số của mô hình và trình tạo mã thông báo, sau đó đặt các biến đường dẫn. Thư mục của trình tạo mã thông báo sẽ nằm trong thư mục chính mà bạn đã tải mô hình xuống, còn trọng số của mô hình sẽ nằm trong thư mục con. Ví dụ:
  • Tệp tokenizer.model sẽ nằm trong /LOCAL/PATH/TO/gemma/flax/2b-it/2).
  • Điểm kiểm tra của mô hình sẽ nằm ở /LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it).
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/tokenizer.model

Thực hiện lấy mẫu/dự đoán

  1. Tải và định dạng điểm kiểm tra mô hình Gemma bằng phương thức gemma.params.load_and_format_params:
from gemma import params as params_lib

params = params_lib.load_and_format_params(CKPT_PATH)
  1. Tải trình tạo mã thông báo Gemma, được tạo bằng sentencepiece.SentencePieceProcessor:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. Để tự động tải cấu hình chính xác từ điểm kiểm tra mô hình Gemma, hãy sử dụng gemma.transformer.TransformerConfig. Đối số cache_size là số bước thời gian trong bộ nhớ đệm Transformer của Gemma. Sau đó, hãy tạo thực thể cho mô hình Gemma dưới dạng transformer bằng gemma.transformer.Transformer (kế thừa từ flax.linen.Module).
from gemma import transformer as transformer_lib

transformer_config = transformer_lib.TransformerConfig.from_params(
    params=params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(transformer_config)
  1. Tạo samplergemma.sampler.Sampler ở trên điểm kiểm tra/trọng số của mô hình Gemma và trình tạo mã thông báo:
from gemma import sampler as sampler_lib

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer'],
)
  1. Viết một câu lệnh trong input_batch và tiến hành dự đoán. Bạn có thể điều chỉnh total_generation_steps (số bước được thực hiện khi tạo phản hồi – ví dụ này sử dụng 100 để duy trì bộ nhớ máy chủ).
prompt = [
    "\n# What is the meaning of life?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=100,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
Prompt:

# What is the meaning of life?
Output:


The question of what the meaning of life is one that has occupied the minds of philosophers, theologians, and individuals for centuries. There is no single, universally accepted answer, but there are many different perspectives on this complex and multifaceted question.

**Some common perspectives on the meaning of life include:**

* **Biological perspective:** From a biological standpoint, the meaning of life is to survive and reproduce.
* **Existential perspective:** Existentialists believe that life is not inherently meaningful and that
  1. (Không bắt buộc) Chạy ô này để giải phóng bộ nhớ nếu bạn đã hoàn thành sổ tay và muốn thử một lời nhắc khác. Sau đó, bạn có thể tạo lại thực thể cho sampler ở bước 3, cũng như tuỳ chỉnh và chạy lời nhắc ở bước 4.
del sampler

Tìm hiểu thêm