Gemma à PyTorch

Afficher sur ai.google.dev Exécuter dans Google Colab Afficher la source sur GitHub

Voici une démonstration rapide de l'exécution de l'inférence Gemma dans PyTorch. Pour en savoir plus, consultez le dépôt GitHub de l'implémentation officielle de PyTorch sur cette page.

Remarque:

  • L'environnement d'exécution Python sans frais pour le CPU Colab et l'environnement d'exécution Python pour le GPU T4 sont suffisants pour exécuter les modèles Gemma 2B et les modèles quantifiés int8 de 7 milliards.
  • Pour les cas d'utilisation avancés d'autres GPU ou TPU, veuillez consulter le fichier README.md dans le dépôt officiel.

1. Configurer l'accès à Kaggle pour Gemma

Pour suivre ce tutoriel, vous devez d'abord suivre les instructions de configuration de la section Configuration de Gemma, qui vous explique comment effectuer les opérations suivantes:

  • Accédez à Gemma sur kaggle.com.
  • Sélectionnez un environnement d'exécution Colab disposant de ressources suffisantes pour exécuter le modèle Gemma.
  • Générez et configurez un nom d'utilisateur et une clé API Kaggle.

Une fois la configuration de Gemma terminée, passez à la section suivante, où vous définirez des variables d'environnement pour votre environnement Colab.

2. Définir des variables d'environnement

Définissez les variables d'environnement pour KAGGLE_USERNAME et KAGGLE_KEY. Lorsque vous êtes invité à accorder l'accès, acceptez de fournir l'accès au secret.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Installer des dépendances

pip install -q -U torch immutabledict sentencepiece

Télécharger les pondérations du modèle

# Choose variant and machine type
VARIANT = '2b-it'
MACHINE_TYPE = 'cuda'

CONFIG = VARIANT[:2]
if CONFIG == '2b':
  CONFIG = '2b-v2'
import os
import kagglehub

# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-2/pyTorch/gemma-2-{VARIANT}')
# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

Télécharger l'implémentation du modèle

# NOTE: The "installation" is just cloning the repo.
git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 239, done.
remote: Counting objects: 100% (123/123), done.
remote: Compressing objects: 100% (68/68), done.
remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116
Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done.
Resolving deltas: 100% (135/135), done.
import sys

sys.path.append('gemma_pytorch')
from gemma.config import GemmaConfig, get_model_config
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch

Configurer le modèle

# Set up model config.
model_config = get_model_config(CONFIG)
model_config.tokenizer = tokenizer_path
model_config.quant = 'quant' in VARIANT

# Instantiate the model and load the weights.
torch.set_default_dtype(model_config.get_dtype())
device = torch.device(MACHINE_TYPE)
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()

Exécuter une inférence

Vous trouverez ci-dessous des exemples de génération en mode chat et de génération avec plusieurs requêtes.

Les modèles Gemma adaptés aux instructions ont été entraînés avec un formateur spécifique qui annoté les exemples d'ajustement des instructions avec des informations supplémentaires, à la fois pendant l'entraînement et l'inférence. Les annotations (1) indiquent les rôles dans une conversation et (2) définissent les tours de piste.

Les jetons d'annotation pertinents sont les suivants:

  • user: tour de l'utilisateur
  • model: rotation du modèle
  • <start_of_turn>: début de tour de parole
  • <end_of_turn><eos>: fin du tour de dialogue

Pour en savoir plus sur la mise en forme des requêtes pour des instructions sur les modèles Gemma réglés, cliquez ici.

L'extrait de code suivant montre comment mettre en forme une requête pour un modèle Gemma avec réglage des instructions à l'aide de modèles de chat utilisateur et de modèle dans une conversation à plusieurs tours.

# Generate with one request in chat mode

# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=128,
)
Chat prompt:
 <start_of_turn>user
What is a good place for travel in the US?<end_of_turn><eos>
<start_of_turn>model
California.<end_of_turn><eos>
<start_of_turn>user
What can I do in California?<end_of_turn><eos>
<start_of_turn>model
"California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo)  \n\nThe more you tell me, the better recommendations I can give! 😊  \n<end_of_turn>"
# Generate sample
model.generate(
    'Write a poem about an llm writing a poem.',
    device=device,
    output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"

En savoir plus

Maintenant que vous avez appris à utiliser Gemma dans Pytorch, vous pouvez explorer les nombreuses autres fonctionnalités de Gemma sur ai.google.dev/gemma. Consultez également les ressources associées suivantes: