Inférence avec CodeGemma à l'aide de JAX et Flax

Afficher sur ai.google.dev Exécuter dans Google Colab Consulter le code source sur GitHub

Nous présentons CodeGemma, un ensemble de modèles de code ouvert basés sur les modèles Gemma de Google DeepMind (Gemma Team et al., 2024). CodeGemma est une famille de modèles ouverts légers et de pointe, créés à partir des recherches et des technologies utilisées pour créer les modèles Gemini.

À partir des modèles pré-entraînés Gemma, les modèles CodeGemma sont entraînés sur plus de 500 à 1 000 milliards de jetons, principalement du code, à l'aide de les mêmes architectures que la famille de modèles Gemma. Par conséquent, les modèles CodeGemma atteignent des performances de code optimales lors de la finalisation de production et de génération, tout en maintenant de compréhension et de raisonnement à grande échelle.

CodeGemma comporte trois variantes:

  • Un modèle pré-entraîné avec du code de 7 milliards
  • Un modèle de code réglé avec les instructions 7B
  • Un modèle 2B, entraîné spécifiquement pour le remplissage de code et la génération ouverte.

Ce guide vous explique comment utiliser le modèle CodeGemma avec Flax pour une tâche de saisie de code.

Configuration

1. Configurer l'accès à Kaggle pour CodeGemma

Pour suivre ce tutoriel, vous devez d'abord suivre les instructions de la page Configuration de Gemma, qui expliquent comment effectuer les opérations suivantes:

  • Accédez à CodeGemma sur kaggle.com.
  • Sélectionnez un environnement d'exécution Colab avec suffisamment de ressources (le GPU T4 est insuffisant, utilisez plutôt le TPU v2) pour exécuter le modèle CodeGemma.
  • Générez et configurez un nom d'utilisateur et une clé API Kaggle.

Une fois la configuration de Gemma terminée, passez à la section suivante, dans laquelle vous allez définir des variables d'environnement pour votre environnement Colab.

2. Définir des variables d'environnement

Définissez les variables d'environnement pour KAGGLE_USERNAME et KAGGLE_KEY. Lorsque le message "Accorder l'accès ?" s'affiche, messages, acceptez de fournir un accès secret.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. Installer la bibliothèque gemma

L'accélération matérielle sans frais de Colab est actuellement insuffisante pour exécuter ce notebook. Si vous utilisez le paiement à l'usage Colab ou Colab Pro, cliquez sur Modifier > Paramètres du notebook > Sélectionnez GPU A100 > Enregistrez pour activer l'accélération matérielle.

Vous devez ensuite installer la bibliothèque Google DeepMind gemma à partir de github.com/google-deepmind/gemma. Si vous obtenez une erreur concernant le "résolveur de dépendances de pip", vous pouvez généralement l'ignorer.

pip install -q git+https://github.com/google-deepmind/gemma.git

4. Importer des bibliothèques

Ce notebook utilise Gemma (qui utilise Flax pour créer les couches de son réseau de neurones) et SentencePiece (pour la tokenisation).

import os
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm

Charger le modèle CodeGemma

Chargez le modèle CodeGemma avec kagglehub.model_download, qui utilise trois arguments:

  • handle: le gestionnaire de modèle de Kaggle
  • path (chaîne facultative) : chemin d'accès local
  • force_download: (valeur booléenne facultative) force le téléchargement à nouveau du modèle.
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7)
Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download...
100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3

Vérifiez l'emplacement des pondérations du modèle et de la fonction de tokenisation, puis définissez les variables de chemin. Le répertoire de tokenisation se trouve dans le répertoire principal dans lequel vous avez téléchargé le modèle, tandis que les pondérations du modèle sont dans un sous-répertoire. Exemple :

  • Le fichier de tokenisation spm.model se trouvera dans /LOCAL/PATH/TO/codegemma/flax/2b-pt/3.
  • Le point de contrôle du modèle se trouvera dans /LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model

Effectuer un échantillonnage/une inférence

Chargez et formatez le point de contrôle du modèle CodeGemma à l'aide de la méthode gemma.params.load_and_format_params:

params = params_lib.load_and_format_params(CKPT_PATH)

Chargez la fonction de tokenisation CodeGemma, créée à l'aide de sentencepiece.SentencePieceProcessor:

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True

Pour charger automatiquement la configuration appropriée à partir du point de contrôle du modèle CodeGemma, utilisez gemma.transformer.TransformerConfig. L'argument cache_size correspond au nombre de pas de temps dans le cache Transformer de CodeGemma. Instanciez ensuite le modèle CodeGemma en tant que model_2b avec gemma.transformer.Transformer (qui hérite de flax.linen.Module).

transformer_config = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(config=transformer_config)

Créez un sampler avec gemma.sampler.Sampler. Elle utilise le point de contrôle du modèle CodeGemma et la fonction de tokenisation.

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer']
)

Créez des variables pour représenter les jetons de remplissage au milieu (Fim) et créez des fonctions d'assistance pour mettre en forme la requête et la sortie générée.

Examinons par exemple le code suivant:

def function(string):
assert function('asdf') == 'fdsa'

Nous aimerions renseigner function pour que l'assertion contienne True. Dans ce cas, le préfixe serait:

"def function(string):\n"

Et le suffixe serait:

"assert function('asdf') == 'fdsa'"

Nous formatons ensuite ceci en requête de type PREFIX-SUFFIX-MIDDLE (la section du milieu à remplir se trouve toujours à la fin de la requête):

"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"

def format_completion_prompt(before, after):
  print(f"\nORIGINAL PROMPT:\n{before}{after}")
  prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
  print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
  return prompt
def format_generated_output(before, after, output):
  print(f"\nGENERATED OUTPUT:\n{repr(output)}")
  formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
  print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
  return formatted_output

Créer une requête et effectuer des inférences Spécifiez le texte du préfixe before et le texte du suffixe after, puis générez la requête mise en forme à l'aide de la fonction d'assistance format_completion prompt.

Vous pouvez modifier total_generation_steps (le nombre d'étapes effectuées lors de la génération d'une réponse. Cet exemple utilise 100 pour préserver la mémoire de l'hôte).

before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
def function(string):
assert function('asdf') == 'fdsa'

FORMATTED PROMPT:
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"

GENERATED OUTPUT:
'    return string[::-1]\n\n<|file_separator|>'

FILL-IN COMPLETION:
def function(string):
    return string[::-1]

assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n    sys.exit(0)"""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import if __name__ == "__main__":
    sys.exit(0)

FORMATTED PROMPT:
'<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n    sys.exit(0)<|fim_middle|>'

GENERATED OUTPUT:
'sys\n<|file_separator|>'

FILL-IN COMPLETION:
import sys
if __name__ == "__main__":
    sys.exit(0)
before = """import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix


FORMATTED PROMPT:
'<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n  # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>'

GENERATED OUTPUT:
'  return np.flip(matrix, axis=1)\n<|file_separator|>'

FILL-IN COMPLETION:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
  return np.flip(matrix, axis=1)

En savoir plus