Джемма в PyTorch

Посмотреть на ai.google.dev Запустить в Google Colab Посмотреть исходный код на GitHub

Это быстрая демонстрация выполнения вывода Gemma в PyTorch. Для получения более подробной информации посетите репозиторий официальной реализации PyTorch на Github здесь .

Обратите внимание, что :

  • Бесплатная среда выполнения Colab CPU Python и среда выполнения T4 GPU Python достаточны для запуска моделей Gemma 2B и квантованных моделей 7B int8.
  • Дополнительные варианты использования других графических процессоров или TPU см. в README.md в официальном репозитории.

Доступ к Кагглу

Чтобы войти в Kaggle, вы можете либо сохранить файл учетных данных kaggle.json в ~/.kaggle/kaggle.json , либо запустить следующее в среде Colab. Дополнительную информацию см. в документации пакета kagglehub .

import kagglehub

kagglehub.login()

Установить зависимости

pip install -q -U torch immutabledict sentencepiece

Скачать вес модели

# Choose variant and machine type
VARIANT = '2b-it'
MACHINE_TYPE = 'cuda'
import os

# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma/pyTorch/{VARIANT}')

# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

Скачать реализацию модели

# NOTE: The "installation" is just cloning the repo.
git clone https://github.com/google/gemma_pytorch.git
import sys

sys.path.append('gemma_pytorch')
from gemma_pytorch.gemma.config import get_config_for_7b, get_config_for_2b
from gemma_pytorch.gemma.model import GemmaForCausalLM

Настройка модели

import torch

# Set up model config.
model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = tokenizer_path
model_config.quant = 'quant' in VARIANT

# Instantiate the model and load the weights.
torch.set_default_dtype(model_config.get_dtype())
device = torch.device(MACHINE_TYPE)
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()

Выполнить вывод

Ниже приведены примеры генерации в режиме чата и генерации с несколькими запросами.

Модели Gemma, настроенные с помощью инструкций, были обучены с помощью специального форматтера, который аннотирует примеры настройки инструкций дополнительной информацией как во время обучения, так и в процессе вывода. Аннотации (1) обозначают роли в разговоре, а (2) обозначают повороты в разговоре. Ниже мы показываем пример фрагмента кода для форматирования приглашения модели с использованием шаблонов чата пользователя и модели в многоходовом разговоре. Соответствующие токены:

  • user : очередь пользователя
  • model : поворот модели
  • <start_of_turn> : начало хода диалога.
  • <end_of_turn> : конец хода диалога.

О форматировании Gemma для настройки инструкций и системных инструкциях читайте здесь .

# Generate with one request in chat mode

# Chat templates
USER_CHAT_TEMPLATE = '<start_of_turn>user\n{prompt}<end_of_turn>\n'
MODEL_CHAT_TEMPLATE = '<start_of_turn>model\n{prompt}<end_of_turn>\n'

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=100,
)
Chat prompt:
 <start_of_turn>user
What is a good place for travel in the US?<end_of_turn>
<start_of_turn>model
California.<end_of_turn>
<start_of_turn>user
What can I do in California?<end_of_turn>
<start_of_turn>model
"* **Visit the Golden Gate Bridge and Alcatraz Island in San Francisco.**\n* **Head to Yosemite National Park and marvel at nature's beauty.**\n* **Explore the bustling metropolis of Los Angeles.**\n* **Relax on the pristine beaches of Santa Monica or Malibu.**\n* **Go whale watching in Monterey Bay.**\n* **Discover the charming coastal towns of Monterey Bay and Carmel-by-the-Sea.**\n* **Visit Disneyland and Disney California Adventure in Anaheim.**\n*"
# Generate sample
model.generate(
    'Write a poem about an llm writing a poem.',
    device=device,
    output_len=60,
)
['\n\nThe fingers dance on the keys,\nA symphony of thoughts and dreams.\nThe mind, a canvas yet uncouth,\nScribbling its secrets in the night.\n\nThe ink, a whispered voice from deep,\nA language ancient, never to sleep.\nEach stroke an echo of']

Узнать больше

Теперь, когда вы узнали, как использовать Gemma в Pytorch, вы можете изучить множество других возможностей Gemma в ai.google.dev/gemma . См. также другие связанные ресурсы: