Inferensi dengan CodeGemma menggunakan JAX dan Flax

Lihat di ai.google.dev Berjalan di Google Colab Lihat sumber di GitHub

Kami mempersembahkan CodeGemma, yakni kumpulan model kode terbuka berdasarkan model Gemma Google DeepMind (Gemma Team et al., 2024). CodeGemma adalah kelompok model terbuka yang ringan dan canggih, dibangun dari riset dan teknologi yang digunakan untuk membuat model Gemini.

Berlanjut dari model Gemma yang telah dilatih sebelumnya, model CodeGemma dilatih lebih lanjut menggunakan lebih dari 500 hingga 1000 miliar token terutama kode, menggunakan arsitektur yang sama dengan kelompok model Gemma. Hasilnya, model CodeGemma mencapai performa kode yang canggih dalam kedua tahap penyelesaian dan pembuatan tugas, sambil tetap mempertahankan dan keterampilan penalaran dalam skala besar.

CodeGemma memiliki 3 varian:

  • Model terlatih kode 7B
  • Model kode yang disesuaikan dengan instruksi sebesar 7B
  • Model 2B, dilatih khusus untuk pengisian kode dan pembuatan open-ended.

Panduan ini memandu Anda menggunakan model CodeGemma dengan Flax untuk tugas penyelesaian kode.

Penyiapan

1. Menyiapkan akses Kaggle untuk CodeGemma

Untuk menyelesaikan tutorial ini, pertama-tama Anda harus mengikuti petunjuk penyiapan di penyiapan Gemma, yang menunjukkan cara melakukan hal berikut:

  • Dapatkan akses ke CodeGemma di kaggle.com.
  • Pilih runtime Colab dengan resource yang memadai (GPU T4 memiliki memori yang tidak cukup, gunakan TPU v2) untuk menjalankan model CodeGemma.
  • Membuat dan mengkonfigurasi nama pengguna dan kunci API Kaggle.

Setelah Anda menyelesaikan penyiapan Gemma, lanjutkan ke bagian berikutnya, untuk menetapkan variabel lingkungan untuk lingkungan Colab Anda.

2. Menetapkan variabel lingkungan

Menetapkan variabel lingkungan untuk KAGGLE_USERNAME dan KAGGLE_KEY. Saat melihat dialog "Berikan akses?", pesan, setuju untuk memberikan akses rahasia.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. Menginstal library gemma

Akselerasi hardware Colab gratis saat ini tidak cukup untuk menjalankan notebook ini. Jika Anda menggunakan Colab Pay As You Go atau Colab Pro, klik Edit > Setelan notebook > Pilih GPU A100 > Simpan untuk mengaktifkan akselerasi hardware.

Selanjutnya, Anda perlu menginstal library gemma Google DeepMind dari github.com/google-deepmind/gemma. Jika mendapatkan error tentang "resolver dependensi pip", Anda biasanya bisa mengabaikannya.

pip install -q git+https://github.com/google-deepmind/gemma.git

4. Mengimpor library

Notebook ini menggunakan Gemma (yang menggunakan Flax untuk membangun lapisan jaringan neuralnya), dan SentencePiece (untuk tokenisasi).

import os
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm

Memuat model CodeGemma

Muat model CodeGemma dengan kagglehub.model_download, yang menggunakan tiga argumen:

  • handle: Handle model dari Kaggle
  • path: (String opsional) Jalur lokal
  • force_download: (Boolean opsional) Memaksa untuk mendownload ulang model
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7)
Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download...
100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3

Periksa lokasi bobot model dan tokenizer, lalu tetapkan variabel jalur. Direktori tokenizer akan berada di direktori utama tempat Anda mendownload model, sedangkan bobot model akan berada di sub-direktori. Contoh:

  • File tokenizer spm.model akan berada di /LOCAL/PATH/TO/codegemma/flax/2b-pt/3
  • Checkpoint model akan berada di /LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model

Melakukan sampling/inferensi

Muat dan format checkpoint model CodeGemma dengan metode gemma.params.load_and_format_params:

params = params_lib.load_and_format_params(CKPT_PATH)

Muat tokenizer CodeGemma, yang dibuat menggunakan sentencepiece.SentencePieceProcessor:

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True

Untuk otomatis memuat konfigurasi yang benar dari checkpoint model CodeGemma, gunakan gemma.transformer.TransformerConfig. Argumen cache_size adalah jumlah langkah waktu dalam cache Transformer CodeGemma. Setelah itu, buat instance model CodeGemma sebagai model_2b dengan gemma.transformer.Transformer (yang diturunkan dari flax.linen.Module).

transformer_config = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(config=transformer_config)

Buat sampler dengan gemma.sampler.Sampler. Token ini menggunakan checkpoint model CodeGemma dan tokenizer.

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer']
)

Buat beberapa variabel untuk mewakili token fill-in-the-middle (fim) dan buat beberapa fungsi bantuan untuk memformat prompt dan output yang dihasilkan.

Sebagai contoh, mari kita lihat kode berikut:

def function(string):
assert function('asdf') == 'fdsa'

Kita ingin mengisi function agar pernyataan memiliki True. Dalam hal ini, awalannya adalah:

"def function(string):\n"

Dan akhirannya adalah:

"assert function('asdf') == 'fdsa'"

Kemudian kita memformatnya menjadi perintah sebagai PREFIX-SUFFIX-MIDDLE (bagian tengah yang harus diisi selalu berada di akhir perintah):

"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"

def format_completion_prompt(before, after):
  print(f"\nORIGINAL PROMPT:\n{before}{after}")
  prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
  print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
  return prompt
def format_generated_output(before, after, output):
  print(f"\nGENERATED OUTPUT:\n{repr(output)}")
  formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
  print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
  return formatted_output

Membuat prompt dan melakukan inferensi. Tentukan teks awalan before dan teks akhiran after, lalu buat perintah berformat menggunakan fungsi bantuan format_completion prompt.

Anda dapat menyesuaikan total_generation_steps (jumlah langkah yang dilakukan saat membuat respons — contoh ini menggunakan 100 untuk menghemat memori host).

before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
def function(string):
assert function('asdf') == 'fdsa'

FORMATTED PROMPT:
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"

GENERATED OUTPUT:
'    return string[::-1]\n\n<|file_separator|>'

FILL-IN COMPLETION:
def function(string):
    return string[::-1]

assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n    sys.exit(0)"""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import if __name__ == "__main__":
    sys.exit(0)

FORMATTED PROMPT:
'<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n    sys.exit(0)<|fim_middle|>'

GENERATED OUTPUT:
'sys\n<|file_separator|>'

FILL-IN COMPLETION:
import sys
if __name__ == "__main__":
    sys.exit(0)
before = """import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix


FORMATTED PROMPT:
'<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n  # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>'

GENERATED OUTPUT:
'  return np.flip(matrix, axis=1)\n<|file_separator|>'

FILL-IN COMPLETION:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
  return np.flip(matrix, axis=1)

Pelajari lebih lanjut