Suy luận bằng CodeGemma bằng JAX và Flax

Xem trên ai.google.dev Chạy trong Google Colab Xem nguồn trên GitHub

Chúng tôi giới thiệu CodeGemma, một tập hợp các mô hình mã mở dựa trên mô hình Gemma của Google DeepMind (Gemma Team và cộng sự, năm 2024). CodeGemma là một dòng mô hình mở, gọn nhẹ, hiện đại, được xây dựng dựa trên chính nghiên cứu và công nghệ dùng để tạo ra các mô hình Gemini.

Tiếp tục từ các mô hình huấn luyện trước của Gemma, các mô hình CodeGemma được huấn luyện thêm về hơn 500 đến 1.000 tỷ token chủ yếu là mã, sử dụng có cùng kiến trúc với dòng mô hình Gemma. Do đó, các mô hình CodeGemma đều đạt được hiệu suất mã hiện đại nhất trong cả hai trường hợp hoàn thành và tạo nhiệm vụ, đồng thời vẫn duy trì được hiểu và suy luận trên quy mô lớn.

CodeGemma có 3 biến thể:

  • Mô hình huấn luyện trước bằng mã 7B
  • Mô hình mã điều chỉnh theo hướng dẫn 7B
  • Mô hình 2B, được huấn luyện riêng cho việc điền mã và tạo kết thúc mở.

Hướng dẫn này sẽ chỉ cho bạn cách sử dụng mô hình CodeGemma với Flax để hoàn thành mã.

Thiết lập

1. Thiết lập quyền truy cập vào Kaggle cho CodeGemma

Để hoàn tất hướng dẫn này, trước tiên bạn cần làm theo hướng dẫn thiết lập trong phần thiết lập Gemma. Các hướng dẫn này sẽ cho bạn biết cách thực hiện những việc sau:

  • Truy cập vào CodeGemma trên kaggle.com.
  • Chọn một môi trường thời gian chạy Colab có đủ tài nguyên (GPU T4 không đủ bộ nhớ, hãy sử dụng TPU phiên bản 2) để chạy mô hình CodeGemma.
  • Tạo và định cấu hình tên người dùng Kaggle và khoá API.

Sau khi thiết lập xong Gemma, hãy chuyển sang phần tiếp theo. Tại đây, bạn sẽ thiết lập các biến môi trường cho môi trường Colab của mình.

2. Đặt các biến môi trường

Thiết lập các biến môi trường cho KAGGLE_USERNAMEKAGGLE_KEY. Khi được nhắc "Cấp quyền truy cập?" tin nhắn, đồng ý cấp quyền truy cập bí mật.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. Cài đặt thư viện gemma

Tính năng tăng tốc phần cứng miễn phí của Colab hiện không đủ để chạy sổ tay này. Nếu bạn đang sử dụng Colab Pay As You Go hoặc Colab Pro, hãy nhấp vào Chỉnh sửa > Cài đặt sổ tay > Chọn GPU A100 > Lưu để bật chế độ tăng tốc phần cứng.

Tiếp theo, bạn cần cài đặt thư viện Google DeepMind gemma từ github.com/google-deepmind/gemma. Nếu gặp lỗi "trình phân giải phần phụ thuộc của pip", bạn thường có thể bỏ qua lỗi đó.

pip install -q git+https://github.com/google-deepmind/gemma.git

4. Nhập thư viện

Sổ tay này sử dụng Gemma (sử dụng Flax để tạo các lớp mạng nơron) và SentencePiece (để mã hoá).

import os
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm

Tải mô hình CodeGemma

Tải mô hình CodeGemma bằng kagglehub.model_download. Thao tác này sẽ nhận 3 đối số:

  • handle: Tên người dùng mô hình trong Kaggle
  • path: (Chuỗi không bắt buộc) Đường dẫn cục bộ
  • force_download: (Boolean không bắt buộc) Buộc tải lại mô hình xuống
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7)
Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download...
100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3

Kiểm tra vị trí của trọng số mô hình và trình tạo mã thông báo, sau đó đặt các biến đường dẫn. Thư mục tokenizer sẽ nằm trong thư mục chính mà bạn đã tải mô hình xuống, còn trọng số của mô hình sẽ nằm trong thư mục con. Ví dụ:

  • Tệp trình tạo mã thông báo spm.model sẽ nằm trong /LOCAL/PATH/TO/codegemma/flax/2b-pt/3
  • Điểm kiểm tra mô hình sẽ nằm trong /LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model

Thực hiện lấy mẫu/suy luận

Tải và định dạng điểm kiểm tra mô hình CodeGemma bằng phương thức gemma.params.load_and_format_params:

params = params_lib.load_and_format_params(CKPT_PATH)

Tải trình tạo mã thông báo CodeGemma, được tạo bằng sentencepiece.SentencePieceProcessor:

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True

Để tự động tải cấu hình chính xác từ điểm kiểm tra mô hình CodeGemma, hãy sử dụng gemma.transformer.TransformerConfig. Đối số cache_size là số bước thời gian trong bộ nhớ đệm Transformer của CodeGemma. Sau đó, hãy tạo thực thể cho mô hình CodeGemma dưới dạng model_2b bằng gemma.transformer.Transformer (kế thừa từ flax.linen.Module).

transformer_config = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(config=transformer_config)

Tạo sampler bằng gemma.sampler.Sampler. Phương thức này sử dụng điểm kiểm tra mô hình CodeGemma và trình tạo mã thông báo.

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer']
)

Tạo một số biến để đại diện cho mã thông báo điền vào giữa (fim) và tạo một số hàm trợ giúp để định dạng lời nhắc và kết quả được tạo.

Ví dụ: hãy xem mã sau:

def function(string):
assert function('asdf') == 'fdsa'

Chúng ta cần điền vào function để câu nhận định chứa True. Trong trường hợp này, tiền tố sẽ là:

"def function(string):\n"

Và hậu tố sẽ là:

"assert function('asdf') == 'fdsa'"

Sau đó, chúng ta định dạng văn bản này thành một câu lệnh là PREFIX-HOWMANY-MIDDLE (phần giữa cần được điền luôn luôn ở cuối lời nhắc):

"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"

def format_completion_prompt(before, after):
  print(f"\nORIGINAL PROMPT:\n{before}{after}")
  prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
  print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
  return prompt
def format_generated_output(before, after, output):
  print(f"\nGENERATED OUTPUT:\n{repr(output)}")
  formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
  print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
  return formatted_output

Tạo lời nhắc và tiến hành suy luận. Chỉ định văn bản tiền tố before và văn bản có hậu tố after, rồi tạo câu lệnh có định dạng bằng hàm trợ giúp format_completion prompt.

Bạn có thể tinh chỉnh total_generation_steps (số bước được thực hiện khi tạo phản hồi – ví dụ này sử dụng 100 để bảo toàn bộ nhớ máy chủ).

before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
def function(string):
assert function('asdf') == 'fdsa'

FORMATTED PROMPT:
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"

GENERATED OUTPUT:
'    return string[::-1]\n\n<|file_separator|>'

FILL-IN COMPLETION:
def function(string):
    return string[::-1]

assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n    sys.exit(0)"""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import if __name__ == "__main__":
    sys.exit(0)

FORMATTED PROMPT:
'<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n    sys.exit(0)<|fim_middle|>'

GENERATED OUTPUT:
'sys\n<|file_separator|>'

FILL-IN COMPLETION:
import sys
if __name__ == "__main__":
    sys.exit(0)
before = """import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix


FORMATTED PROMPT:
'<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n  # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>'

GENERATED OUTPUT:
'  return np.flip(matrix, axis=1)\n<|file_separator|>'

FILL-IN COMPLETION:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
  return np.flip(matrix, axis=1)

Tìm hiểu thêm

  • Bạn có thể tìm hiểu thêm về thư viện Google DeepMind gemma trên GitHub. Thư viện này chứa các chuỗi tài liệu của các mô-đun mà bạn đã sử dụng trong hướng dẫn này, chẳng hạn như gemma.params, gemma.transformergemma.sampler.
  • Các thư viện sau đây có các trang web tài liệu riêng: core JAX, FlaxOrbax.
  • Để xem tài liệu về trình tạo mã thông báo/trình huỷ mã thông báo sentencepiece, hãy tham khảo kho lưu trữ GitHub sentencepiece của Google.
  • Để xem tài liệu về kagglehub, hãy tham khảo README.md trên kho lưu trữ GitHub kagglehub của Kaggle.
  • Tìm hiểu cách sử dụng mô hình Gemma với Vertex AI của Google Cloud.
  • Nếu bạn đang dùng TPU của Google Cloud (phiên bản 3-8 trở lên), đừng quên cập nhật lên gói jax[tpu] mới nhất (!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html), khởi động lại thời gian chạy và kiểm tra để đảm bảo rằng các phiên bản jaxjaxlib khớp (!pip list | grep jax). Nhờ đó, RuntimeError có thể phát sinh do phiên bản jaxlibjax không khớp nhau. Để biết thêm hướng dẫn cài đặt JAX, hãy tham khảo tài liệu về JAX.