JAX ve Flax kullanarak CodeGemma ile çıkarım

ai.google.dev adresinde görüntüleme Google Colab'da çalıştırma Kaynağı GitHub'da görüntüleyin

Google DeepMind'ın Gemma modellerine (Gemma Team et al., 2024). CodeGemma, Gemini modellerini oluşturmak için kullanılan aynı araştırma ve teknolojiden oluşturulmuş, hafif ve son teknoloji açık modellerden oluşan bir ailedir.

Gemma önceden eğitilmiş modellerinden devam eden CodeGemma modelleri, Gemma model ailesiyle aynı mimariler kullanılarak 500 ila 1.000 milyardan fazla, çoğunlukla kod jetonu üzerinde daha da eğitilir. Sonuç olarak CodeGemma modelleri, hem tamamlama hem de oluşturma görevlerinde en son kod performansını sağlarken geniş ölçekte güçlü anlama ve akıl yürütme becerilerini korur.

CodeGemma'nın 3 varyantı vardır:

  • 7 milyar kod içeren önceden eğitilmiş bir model
  • 7B talimat ayarlı kod modeli
  • Özellikle kod doldurma ve açık uçlu oluşturma için eğitilmiş bir 2B model.

Bu kılavuzda, kod tamamlama görevi için CodeGemma modelini Flax ile kullanma konusunda size yol gösterilmektedir.

Kurulum

1. CodeGemma için Kaggle erişimini ayarlama

Bu eğitimde yer alan adımları tamamlamak için öncelikle Gemma kurulumu sayfasında yer alan kurulum talimatlarını uygulamanız gerekir. Bu talimatlarda aşağıdakilerin nasıl yapılacağı gösterilmektedir:

  • kaggle.com adresinden CodeGemma'ya erişin.
  • CodeGemma modelini çalıştırmak için yeterli kaynaklara sahip bir Colab çalışma zamanı seçin (T4 GPU'da yeterli bellek yoktur, bunun yerine TPU v2'yi kullanın).
  • Kaggle kullanıcı adı ve API anahtarı oluşturup yapılandırın.

Gemma kurulumunu tamamladıktan sonra, Colab ortamınız için ortam değişkenlerini ayarlayacağınız sonraki bölüme geçin.

2. Ortam değişkenlerini ayarlama

KAGGLE_USERNAME ve KAGGLE_KEY için ortam değişkenlerini ayarlayın. "Erişim izni verilsin mi?" mesajı gösterildiğinde gizli erişim izni vermeyi kabul edin.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. gemma kitaplığını yükleme

Ücretsiz Colab donanım hızlandırması şu anda bu not defterini çalıştırmak için insufficient. Colab Pay As You Go veya Colab Pro kullanıyorsanız donanım hızlandırmayı etkinleştirmek için Düzenle > Not defteri ayarları > A100 GPU'yu seçin > Kaydet'i tıklayın.

Ardından, github.com/google-deepmind/gemma adresinden Google DeepMind gemma kitaplığını yüklemeniz gerekir. "pip'in bağımlılık çözücüsüyle" ilgili bir hata alırsanız genellikle bunu yoksayabilirsiniz.

pip install -q git+https://github.com/google-deepmind/gemma.git

4. Kitaplıkları içe aktarma

Bu not defteri, Gemma (sinir ağı katmanlarını oluşturmak için Flax'ı kullanır) ve SentencePiece (token oluşturmak için) kullanır.

import os
from gemma.deprecated import params as params_lib
from gemma.deprecated import sampler as sampler_lib
from gemma.deprecated import transformer as transformer_lib
import sentencepiece as spm

CodeGemma modelini yükleme

CodeGemma modelini, üç bağımsız değişken alan kagglehub.model_download ile yükleyin:

  • handle: Kaggle'daki model herkese açık kullanıcı adı
  • path: (İsteğe bağlı dize) Yerel yol
  • force_download: (İsteğe bağlı boole) Modelin yeniden indirilmesini zorunlu kılar
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7)
Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download...
100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3

Model ağırlıklarının ve dize ayırıcının konumunu kontrol edin, ardından yol değişkenlerini ayarlayın. Söz dizimi ayrıştırıcı dizini, modeli indirdiğiniz ana dizinde bulunurken model ağırlıkları bir alt dizinde olur. Örneğin:

  • spm.model kelime ayırıcı dosyası /LOCAL/PATH/TO/codegemma/flax/2b-pt/3 içinde yer alır.
  • Model kontrol noktası /LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt'te olacaktır.
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model

Örnekleme/çıkarım yapma

CodeGemma model kontrol noktasını gemma.params.load_and_format_params yöntemiyle yükleyip biçimlendirin:

params = params_lib.load_and_format_params(CKPT_PATH)

sentencepiece.SentencePieceProcessor kullanılarak oluşturulan CodeGemma kelime dizesini ayırıcısını yükleyin:

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True

CodeGemma model kontrol noktasından doğru yapılandırmayı otomatik olarak yüklemek için gemma.deprecated.transformer.TransformerConfig değerini kullanın. cache_size bağımsız değişkeni, CodeGemma Transformer önbelleğindeki zaman adımlarının sayısıdır. Ardından, CodeGemma modelini gemma.deprecated.transformer.Transformer (flax.linen.Module'ten devralınan) ile model_2b olarak örneklendirin.

transformer_config = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(config=transformer_config)

gemma.sampler.Sampler ile sampler oluşturun. CodeGemma model kontrol noktasını ve tokenizörü kullanır.

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer']
)

Boşluk doldurma (fim) jetonlarını temsil edecek bazı değişkenler ve istemi ve oluşturulan çıkışı biçimlendirecek bazı yardımcı işlevler oluşturun.

Örneğin, aşağıdaki koda bakalım:

def function(string):
assert function('asdf') == 'fdsa'

İddianın True olması için function değerini doldurmak istiyoruz. Bu durumda ön ek şu şekilde olur:

"def function(string):\n"

Son ek ise şöyle olur:

"assert function('asdf') == 'fdsa'"

Ardından bunu PREFIX-SUFFIX-MIDDLE (ÖNEK-SONEK-ORTA) şeklinde bir istem olarak biçimlendiririz (doldurulması gereken orta bölüm her zaman istemin sonundadır):

"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"

def format_completion_prompt(before, after):
  print(f"\nORIGINAL PROMPT:\n{before}{after}")
  prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
  print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
  return prompt
def format_generated_output(before, after, output):
  print(f"\nGENERATED OUTPUT:\n{repr(output)}")
  formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
  print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
  return formatted_output

İstem oluşturun ve çıkarım yapın. before ön ek metnini ve after son ek metnini belirtin ve format_completion prompt yardımcı işlevini kullanarak biçimlendirilmiş istemi oluşturun.

total_generation_steps değerini (yanıt oluşturulurken gerçekleştirilen adım sayısı. Bu örnekte, ana makine belleğini korumak için 100 kullanılır) değiştirebilirsiniz.

before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
def function(string):
assert function('asdf') == 'fdsa'

FORMATTED PROMPT:
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"

GENERATED OUTPUT:
'    return string[::-1]\n\n<|file_separator|>'

FILL-IN COMPLETION:
def function(string):
    return string[::-1]

assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n    sys.exit(0)"""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import if __name__ == "__main__":
    sys.exit(0)

FORMATTED PROMPT:
'<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n    sys.exit(0)<|fim_middle|>'

GENERATED OUTPUT:
'sys\n<|file_separator|>'

FILL-IN COMPLETION:
import sys
if __name__ == "__main__":
    sys.exit(0)
before = """import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix


FORMATTED PROMPT:
'<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n  # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>'

GENERATED OUTPUT:
'  return np.flip(matrix, axis=1)\n<|file_separator|>'

FILL-IN COMPLETION:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
  return np.flip(matrix, axis=1)

Daha fazla bilgi