JAX ve Flax kullanarak CodeGemma ile çıkarım

ai.google.dev'de görüntüleyin Google Colab'de çalıştır Kaynağı GitHub'da görüntüle

Google DeepMind'ın Gemma modellerine dayanan bir açık kod modelleri koleksiyonu olan CodeGemma'yı (Gemma Ekibi ve diğerleri, 2024). CodeGemma, Gemini modellerini oluşturmak için kullanılan araştırma ve teknolojiyle geliştirilmiş hafif, son teknoloji ürünü açık modeller ailesidir.

Önceden eğitilmiş Gemma modellerinden devam eden CodeGemma modelleri, temelde 500 ila 1.000 milyardan fazla kod kullanılarak daha da eğitilir. model ailesi ile aynı mimarilere sahip. Sonuç olarak, CodeGemma modelleri hem tamamlama hem de işlem tamamlamada ve ürün oluşturmak için kullanacağınız, anlama ve akıl yürütme becerileridir.

CodeGemma'nın 3 varyantı vardır:

  • Önceden eğitilmiş 7B kodlu model
  • Talimat ayarlı 7B kod modeli
  • Kod doldurma ve açık uçlu oluşturma için özel olarak eğitilmiş 2B model.

Bu kılavuz, bir kod tamamlama görevi için CodeGemma modelini Flax ile kullanma konusunda size yol gösterir.

Kurulum

1. CodeGemma için Kaggle erişimini ayarlama

Bu eğiticiyi tamamlamak için önce Gemma kurulumu'ndaki kurulum talimatlarını uygulamanız gerekir. Bu talimatlarda, aşağıdakileri nasıl yapacağınızı öğrenebilirsiniz:

  • kaggle.com adresinden CodeGemma'ya erişin.
  • CodeGemma modelini çalıştırmak için yeterli kaynağa sahip bir Colab çalışma zamanı seçin (T4 GPU'nun yeterli belleği yok, bunun yerine TPU v2'yi kullanın).
  • Kaggle kullanıcı adı ve API anahtarı oluşturup yapılandırın.

Gemma kurulumunu tamamladıktan sonra bir sonraki bölüme geçin. Burada, Colab ortamınız için ortam değişkenlerini ayarlayabilirsiniz.

2. Ortam değişkenlerini ayarlama

KAGGLE_USERNAME ve KAGGLE_KEY için ortam değişkenlerini ayarlayın. "Erişim izni verilsin mi?" sorusuyla karşılaştığınızda gizli erişim izni vermeyi kabul edin.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. gemma kitaplığını yükle

Ücretsiz Colab donanım hızlandırma özelliği şu anda bu not defterini çalıştırmak için yetersiz. Colab Pay As You Go veya Colab Pro kullanıyorsanız Düzenle'yi tıklayın > Not defteri ayarları > A100 GPU > seçeneğini belirleyin Donanım hızlandırmayı etkinleştirmek için Kaydet'i seçin.

Ardından, github.com/google-deepmind/gemma üzerinden Google DeepMind gemma kitaplığını yüklemeniz gerekir. "pip'in bağımlılık çözümleyicisi" hatası alırsanız genellikle bunu göz ardı edebilirsiniz.

pip install -q git+https://github.com/google-deepmind/gemma.git

4. Kitaplıkları içe aktar

Bu not defterinde Gemma (nöral ağ katmanlarını oluşturmak için Flax kullanılır) ve SentencePiece (tokenleştirme için) kullanılır.

import os
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm

CodeGemma modelini yükleme

CodeGemma modelini, üç bağımsız değişken alan kagglehub.model_download ile yükleyin:

  • handle: Kaggle'ın model tutma yeri
  • path: (İsteğe bağlı dize) Yerel yol
  • force_download: (İsteğe bağlı boole) Modeli yeniden indirmeye zorlar
ziyaret edin.
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7)
Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download...
100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3

Model ağırlıklarının ve tokenleştiricinin konumunu kontrol edin, ardından yol değişkenlerini ayarlayın. Jeton oluşturucu dizini, modeli indirdiğiniz ana dizinde, model ağırlıkları ise bir alt dizinde yer alır. Örneğin:

  • spm.model jeton oluşturucu dosyası /LOCAL/PATH/TO/codegemma/flax/2b-pt/3 içinde olacak
  • Model kontrol noktası /LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt içinde olacak
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model

Örnekleme/çıkarım gerçekleştirme

CodeGemma modeli kontrol noktasını gemma.params.load_and_format_params yöntemiyle yükleyin ve biçimlendirin:

params = params_lib.load_and_format_params(CKPT_PATH)

sentencepiece.SentencePieceProcessor kullanılarak oluşturulan CodeGemma jeton oluşturucuyu yükleyin:

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True

CodeGemma model kontrol noktasından doğru yapılandırmayı otomatik olarak yüklemek için gemma.transformer.TransformerConfig kodunu kullanın. cache_size bağımsız değişkeni, CodeGemma Transformer önbelleğindeki zaman adımlarının sayısıdır. Ardından, gemma.transformer.Transformer (flax.linen.Module öğesinden devralır) ile CodeGemma modelini model_2b olarak örneklendirin.

transformer_config = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(config=transformer_config)

gemma.sampler.Sampler ile bir sampler oluşturun. CodeGemma model kontrol noktası ve jeton oluşturucuyu kullanır.

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer']
)

Ortadaki doldur (fim) jetonlarını temsil edecek bazı değişkenler oluşturun. İstemi ve oluşturulan çıkışı biçimlendirmek için bazı yardımcı işlevler oluşturun.

Örneğin, aşağıdaki kodu inceleyelim:

def function(string):
assert function('asdf') == 'fdsa'

Onaylamanın True değerini koruması için function öğesini doldurmak istiyoruz. Bu durumda, ön ek şöyle olur:

"def function(string):\n"

Sonek de şöyle olur:

"assert function('asdf') == 'fdsa'"

Daha sonra bunu bir istem olarak PREFIX-SUFFIX-MIDDLE olarak biçimlendiririz (doldurulması gereken orta bölüm her zaman istemin sonundadır):

"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"

def format_completion_prompt(before, after):
  print(f"\nORIGINAL PROMPT:\n{before}{after}")
  prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
  print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
  return prompt
def format_generated_output(before, after, output):
  print(f"\nGENERATED OUTPUT:\n{repr(output)}")
  formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
  print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
  return formatted_output

İstem oluşturun ve çıkarım yapın. before ön ekini ve after metni son ekini belirtin ve format_completion prompt yardımcı işlevini kullanarak biçimlendirilmiş istemi oluşturun.

total_generation_steps üzerinde ince ayar yapabilirsiniz (yanıt oluşturulurken gerçekleştirilen adım sayısı; bu örnekte ana makine belleğini korumak için 100 kullanılır).

before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
def function(string):
assert function('asdf') == 'fdsa'

FORMATTED PROMPT:
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"

GENERATED OUTPUT:
'    return string[::-1]\n\n<|file_separator|>'

FILL-IN COMPLETION:
def function(string):
    return string[::-1]

assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n    sys.exit(0)"""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import if __name__ == "__main__":
    sys.exit(0)

FORMATTED PROMPT:
'<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n    sys.exit(0)<|fim_middle|>'

GENERATED OUTPUT:
'sys\n<|file_separator|>'

FILL-IN COMPLETION:
import sys
if __name__ == "__main__":
    sys.exit(0)
before = """import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix


FORMATTED PROMPT:
'<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n  # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>'

GENERATED OUTPUT:
'  return np.flip(matrix, axis=1)\n<|file_separator|>'

FILL-IN COMPLETION:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
  return np.flip(matrix, axis=1)

Daha fazla bilgi