Посмотреть на ai.google.dev | Запустить в Google Colab | Открыть в Vertex AI | Посмотреть исходный код на GitHub |
В этом руководстве показано, как выполнить базовую выборку/вывод с помощью модели RecurrentGemma 2B Instruct с использованием библиотеки recurrentgemma
Google DeepMind , написанной с использованием JAX (библиотека высокопроизводительных численных вычислений), Flax (библиотека нейронных сетей на основе JAX), Orbax (библиотека нейронных сетей на основе JAX). библиотека на основе JAX для обучения утилитам, таким как создание контрольных точек) и SentencePiece (библиотека токенизатора/детокенизатора). Хотя Flax не используется напрямую в этом блокноте, Flax использовался для создания Gemma и RecurrentGemma (модель Гриффина).
Этот ноутбук может работать в Google Colab с графическим процессором T4 (перейдите в «Редактирование» > «Настройки ноутбука» > в разделе «Аппаратный ускоритель» выберите «T4 GPU» ).
Настраивать
В следующих разделах описываются шаги по подготовке записной книжки для использования модели RecurrentGemma, включая доступ к модели, получение ключа API и настройку среды выполнения записной книжки.
Настройте доступ к Kaggle для Джеммы
Чтобы выполнить это руководство, сначала необходимо следовать инструкциям по настройке, аналогичным настройке Gemma, за некоторыми исключениями:
- Получите доступ к RecurrentGemma (вместо Gemma) на kaggle.com .
- Выберите среду выполнения Colab с достаточными ресурсами для запуска модели RecurrentGemma.
- Создайте и настройте имя пользователя Kaggle и ключ API.
После завершения настройки RecurrentGemma перейдите к следующему разделу, где вы установите переменные среды для вашей среды Colab.
Установить переменные среды
Установите переменные среды для KAGGLE_USERNAME
и KAGGLE_KEY
. При появлении запроса «Предоставить доступ?» сообщения, согласитесь предоставить секретный доступ.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
Установите библиотеку recurrentgemma
В этом ноутбуке основное внимание уделяется использованию бесплатного графического процессора Colab. Чтобы включить аппаратное ускорение, нажмите «Редактировать» > «Настройки ноутбука» > «Выбрать графический процессор T4» > «Сохранить» .
Далее вам необходимо установить библиотеку recurrentgemma
Google DeepMind с github.com/google-deepmind/recurrentgemma
. Если вы получаете сообщение об ошибке «преобразователь зависимостей pip», обычно вы можете игнорировать его.
pip install git+https://github.com/google-deepmind/recurrentgemma.git
Загрузите и подготовьте модель RecurrentGemma.
- Загрузите модель RecurrentGemma с помощью
kagglehub.model_download
, который принимает три аргумента:
-
handle
: ручка модели от Kaggle. -
path
: (Необязательная строка) Локальный путь -
force_download
: (Необязательное логическое значение) Принудительно повторно загрузить модель.
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub
RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download... 100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s] Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
- Проверьте расположение весов модели и токенизатора, затем установите переменные пути. Каталог токенизатора будет находиться в основном каталоге, в который вы загрузили модель, а веса модели будут находиться в подкаталоге. Например:
- Файл
tokenizer.model
будет находиться в/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1
). - Контрольная точка модели будет находиться в
/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it
).
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model
Выполнить выборку/вывод
- Загрузите контрольную точку модели RecurrentGemma с помощью метода
recurrentgemma.jax.load_parameters
. Аргументsharding
, установленный на"single_device"
загружает все параметры модели на одном устройстве.
import recurrentgemma
from recurrentgemma import jax as recurrentgemma
params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
- Загрузите токенизатор модели RecurrentGemma, созданный с помощью
sentencepiece.SentencePieceProcessor
:
import sentencepiece as spm
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
- Чтобы автоматически загрузить правильную конфигурацию из контрольной точки модели RecurrentGemma, используйте
recurrentgemma.GriffinConfig.from_flax_params_or_variables
. Затем создайте экземпляр модели Гриффина с помощьюrecurrentgemma.jax.Griffin
.
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
flax_params_or_variables=params)
model = recurrentgemma.Griffin(model_config)
- Создайте
sampler
с помощьюrecurrentgemma.jax.Sampler
поверх контрольной точки/весов модели RecurrentGemma и токенизатора:
sampler = recurrentgemma.Sampler(
model=model,
vocab=vocab,
params=params,
)
- Напишите подсказку в
prompt
и выполните вывод. Вы можете настроитьtotal_generation_steps
(количество шагов, выполняемых при генерации ответа — в этом примере используется50
для экономии памяти хоста).
prompt = [
"\n# 5+9=?",
]
reply = sampler(input_strings=prompt,
total_generation_steps=50,
)
for input_string, out_string in zip(prompt, reply.text):
print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation. warnings.warn("Some donated buffers were not usable:" Prompt: # 5+9=? Output: # Answer: 14 # Explanation: 5 + 9 = 14.
Узнать больше
- Вы можете узнать больше о библиотеке
recurrentgemma
Google DeepMind на GitHub , которая содержит строки документации методов и модулей, которые вы использовали в этом руководстве, таких какrecurrentgemma.jax.load_parameters
,recurrentgemma.jax.Griffin
иrecurrentgemma.jax.Sampler
. - Следующие библиотеки имеют собственные сайты документации: core JAX , Flax и Orbax .
- Документацию по токенизатору/детокенизатору
sentencepiece
можно найти в репозитории Googlesentencepiece
на GitHub . - Документацию по
kagglehub
можно найтиREADME.md
в репозиторииkagglehub
на GitHub . - Узнайте, как использовать модели Gemma с Google Cloud Vertex AI .
- Ознакомьтесь с документом RecurrentGemma: Moving Past Transformers for Efficient Open Language Models от Google DeepMind.
- Прочитайте статью Griffin: Mixing Gated Linear Recurrens with Local Attention for Efficient Language Models от GoogleDeepMind, чтобы узнать больше об архитектуре модели, используемой RecurrentGemma.