Вывод с помощью RecurrentGemma с использованием JAX и Flax

Посмотреть на ai.google.dev Запустить в Google Colab Открыть в Vertex AI Посмотреть исходный код на GitHub

В этом руководстве показано, как выполнить базовую выборку/вывод с помощью модели RecurrentGemma 2B Instruct с использованием библиотеки recurrentgemma Google DeepMind , написанной с использованием JAX (библиотека высокопроизводительных численных вычислений), Flax (библиотека нейронных сетей на основе JAX), Orbax (библиотека нейронных сетей на основе JAX). библиотека на основе JAX для обучения утилитам, таким как создание контрольных точек) и SentencePiece (библиотека токенизатора/детокенизатора). Хотя Flax не используется напрямую в этом блокноте, Flax использовался для создания Gemma и RecurrentGemma (модель Гриффина).

Этот ноутбук может работать в Google Colab с графическим процессором T4 (перейдите в «Редактирование» > «Настройки ноутбука» > в разделе «Аппаратный ускоритель» выберите «T4 GPU» ).

Настраивать

В следующих разделах описываются шаги по подготовке записной книжки для использования модели RecurrentGemma, включая доступ к модели, получение ключа API и настройку среды выполнения записной книжки.

Настройте доступ к Kaggle для Джеммы

Чтобы выполнить это руководство, сначала необходимо следовать инструкциям по настройке, аналогичным настройке Gemma, за некоторыми исключениями:

  • Получите доступ к RecurrentGemma (вместо Gemma) на kaggle.com .
  • Выберите среду выполнения Colab с достаточными ресурсами для запуска модели RecurrentGemma.
  • Создайте и настройте имя пользователя Kaggle и ключ API.

После завершения настройки RecurrentGemma перейдите к следующему разделу, где вы установите переменные среды для вашей среды Colab.

Установить переменные среды

Установите переменные среды для KAGGLE_USERNAME и KAGGLE_KEY . При появлении запроса «Предоставить доступ?» сообщения, согласитесь предоставить секретный доступ.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Установите библиотеку recurrentgemma

В этом ноутбуке основное внимание уделяется использованию бесплатного графического процессора Colab. Чтобы включить аппаратное ускорение, нажмите «Редактировать» > «Настройки ноутбука» > «Выбрать графический процессор T4» > «Сохранить» .

Далее вам необходимо установить библиотеку recurrentgemma Google DeepMind с github.com/google-deepmind/recurrentgemma . Если вы получаете сообщение об ошибке «преобразователь зависимостей pip», обычно вы можете игнорировать его.

pip install git+https://github.com/google-deepmind/recurrentgemma.git

Загрузите и подготовьте модель RecurrentGemma.

  1. Загрузите модель RecurrentGemma с помощью kagglehub.model_download , который принимает три аргумента:
  • handle : ручка модели от Kaggle.
  • path : (Необязательная строка) Локальный путь
  • force_download : (Необязательное логическое значение) Принудительно повторно загрузить модель.
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download...
100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s]
Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
  1. Проверьте расположение весов модели и токенизатора, затем установите переменные пути. Каталог токенизатора будет находиться в основном каталоге, в который вы загрузили модель, а веса модели будут находиться в подкаталоге. Например:
  • Файл tokenizer.model будет находиться в /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1 ).
  • Контрольная точка модели будет находиться в /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it ).
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model

Выполнить выборку/вывод

  1. Загрузите контрольную точку модели RecurrentGemma с помощью метода recurrentgemma.jax.load_parameters . Аргумент sharding , установленный на "single_device" загружает все параметры модели на одном устройстве.
import recurrentgemma
from recurrentgemma import jax as recurrentgemma

params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
  1. Загрузите токенизатор модели RecurrentGemma, созданный с помощью sentencepiece.SentencePieceProcessor :
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. Чтобы автоматически загрузить правильную конфигурацию из контрольной точки модели RecurrentGemma, используйте recurrentgemma.GriffinConfig.from_flax_params_or_variables . Затем создайте экземпляр модели Гриффина с помощью recurrentgemma.jax.Griffin .
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
    flax_params_or_variables=params)

model = recurrentgemma.Griffin(model_config)
  1. Создайте sampler с помощью recurrentgemma.jax.Sampler поверх контрольной точки/весов модели RecurrentGemma и токенизатора:
sampler = recurrentgemma.Sampler(
    model=model,
    vocab=vocab,
    params=params,
)
  1. Напишите подсказку в prompt и выполните вывод. Вы можете настроить total_generation_steps (количество шагов, выполняемых при генерации ответа — в этом примере используется 50 для экономии памяти хоста).
prompt = [
    "\n# 5+9=?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=50,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
Prompt:

# 5+9=?
Output:


# Answer: 14

# Explanation: 5 + 9 = 14.

Узнать больше