JAX ve Flax kullanarak RecurrentGemma ile çıkarım

ai.google.dev'de görüntüleyin Google Colab'de çalıştır Vertex AI'da aç Kaynağı GitHub'da görüntüle

Bu eğitimde, JAX (yüksek performanslı sayısal bilgi işlem kitaplığı), Flax (JAX tabanlı nöral ağ kitaplığı), Orbax (checkpointing JAXie_de1 benzeri jeton kitaplığı1 gibi Sentapax tabanlı kitaplık) kitaplığı ile yazılmış Google DeepMind'ın recurrentgemma kitaplığını kullanarak RecurrentGemma 2B Talimat modeliyle nasıl temel örnekleme/çıkarım yapılacağı gösterilmektedir.SentencePiece Flax, doğrudan bu not defterinde kullanılmasa da Gemma ve RecurrentGemma'yı (griffin modeli) oluşturmak için Flax kullanılmıştır.

Bu not defteri, Google Colab'de T4 GPU ile çalışabilir (Düzenle > Not defteri ayarları > Donanım hızlandırıcı'nın altında T4 GPU'yu seçin).

Kurulum

Aşağıdaki bölümlerde, model erişimi, API anahtarı alma ve not defteri çalışma zamanını yapılandırma da dahil olmak üzere bir not defterini RecurrentGemma modelini kullanmak üzere hazırlama adımları açıklanmaktadır.

Gemma için Kaggle erişimini ayarlama

Bu eğiticiyi tamamlamak için önce birkaç istisna dışında Gemma kurulumuna benzer kurulum talimatlarını uygulamanız gerekir:

  • kaggle.com adresinden RecurrentGemma'ya (Gemma yerine) erişin.
  • RecurrentGemma modelini çalıştırmak için yeterli kaynağa sahip bir Colab çalışma zamanı seçin.
  • Kaggle kullanıcı adı ve API anahtarı oluşturup yapılandırın.

RecurrentGemma kurulumunu tamamladıktan sonra, Colab ortamınız için ortam değişkenlerini ayarlayacağınız bir sonraki bölüme geçin.

Ortam değişkenlerini ayarlama

KAGGLE_USERNAME ve KAGGLE_KEY için ortam değişkenlerini ayarlayın. "Erişim izni verilsin mi?" sorusuyla karşılaştığınızda gizli erişim izni vermeyi kabul edin.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

recurrentgemma kitaplığını yükle

Bu not defteri, ücretsiz Colab GPU kullanmaya odaklanmıştır. Donanım hızlandırmayı etkinleştirmek için Edit (Düzenle) > Not defteri ayarları > T4 GPU'yu seçin > Kaydet'i seçin.

Ardından, github.com/google-deepmind/recurrentgemma üzerinden Google DeepMind recurrentgemma kitaplığını yüklemeniz gerekir. "pip'in bağımlılık çözümleyicisi" hatası alırsanız genellikle bunu göz ardı edebilirsiniz.

pip install git+https://github.com/google-deepmind/recurrentgemma.git

RecurrentGemma modelini yükleme ve hazırlama

  1. Üç bağımsız değişken alan kagglehub.model_download ile RecurrentGemma modelini yükleyin:
  • handle: Kaggle'ın model tutma yeri
  • path: (İsteğe bağlı dize) Yerel yol
  • force_download: (İsteğe bağlı boole) Modeli yeniden indirmeye zorlar
import kagglehub

RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download...
100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s]
Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
  1. Model ağırlıklarının ve belirteç oluşturucunun konumunu kontrol edin, ardından yol değişkenlerini ayarlayın. Jeton oluşturucu dizini, modeli indirdiğiniz ana dizinde, model ağırlıkları ise bir alt dizinde yer alır. Örneğin:
  • tokenizer.model dosyası /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1 klasöründe yer alır.)
  • Model kontrol noktası /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it konumunda olacaktır.)
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model

Örnekleme/çıkarım gerçekleştirme

  1. recurrentgemma.jax.load_parameters yöntemiyle RecurrentGemma modeli kontrol noktasını yükleyin. "single_device" olarak ayarlanan sharding bağımsız değişkeni, tüm model parametrelerini tek bir cihaza yükler.
import recurrentgemma
from recurrentgemma import jax as recurrentgemma

params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
  1. sentencepiece.SentencePieceProcessor kullanılarak oluşturulan RecurrentGemma model jeton oluşturucuyu yükleyin:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. RecurrentGemma model kontrol noktasından doğru yapılandırmayı otomatik olarak yüklemek için recurrentgemma.GriffinConfig.from_flax_params_or_variables kullanın. Ardından, Griffin modelini recurrentgemma.jax.Griffin ile örneklendirin.
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
    flax_params_or_variables=params)

model = recurrentgemma.Griffin(model_config)
  1. RecurrentGemma model kontrol noktası/ağırlıkları ve jeton oluşturucunun üzerine recurrentgemma.jax.Sampler ile bir sampler oluşturun:
sampler = recurrentgemma.Sampler(
    model=model,
    vocab=vocab,
    params=params,
)
  1. prompt dilinde bir istem yazıp çıkarım yapın. total_generation_steps üzerinde ince ayar yapabilirsiniz (yanıt oluşturulurken gerçekleştirilen adım sayısı; bu örnekte ana makine belleğini korumak için 50 kullanılır).
prompt = [
    "\n# 5+9=?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=50,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
Prompt:

# 5+9=?
Output:


# Answer: 14

# Explanation: 5 + 9 = 14.

Daha fazla bilgi