Afficher sur ai.google.dev | Exécuter dans Google Colab | Consulter le code source sur GitHub |
Nous présentons CodeGemma, un ensemble de modèles de code ouvert basés sur les modèles Gemma de Google DeepMind (Gemma Team et al., 2024). CodeGemma est une famille de modèles ouverts légers et de pointe, créés à partir des recherches et des technologies utilisées pour créer les modèles Gemini.
À partir des modèles pré-entraînés Gemma, les modèles CodeGemma sont entraînés sur plus de 500 à 1 000 milliards de jetons, principalement du code, à l'aide de les mêmes architectures que la famille de modèles Gemma. Par conséquent, les modèles CodeGemma atteignent des performances de code optimales lors de la finalisation de production et de génération, tout en maintenant de compréhension et de raisonnement à grande échelle.
CodeGemma comporte trois variantes:
- Un modèle pré-entraîné avec du code de 7 milliards
- Un modèle de code réglé avec les instructions 7B
- Un modèle 2B, entraîné spécifiquement pour le remplissage de code et la génération ouverte.
Ce guide vous explique comment utiliser le modèle CodeGemma avec Flax pour une tâche de saisie de code.
Configuration
1. Configurer l'accès à Kaggle pour CodeGemma
Pour suivre ce tutoriel, vous devez d'abord suivre les instructions de la page Configuration de Gemma, qui expliquent comment effectuer les opérations suivantes:
- Accédez à CodeGemma sur kaggle.com.
- Sélectionnez un environnement d'exécution Colab avec suffisamment de ressources (le GPU T4 est insuffisant, utilisez plutôt le TPU v2) pour exécuter le modèle CodeGemma.
- Générez et configurez un nom d'utilisateur et une clé API Kaggle.
Une fois la configuration de Gemma terminée, passez à la section suivante, dans laquelle vous allez définir des variables d'environnement pour votre environnement Colab.
2. Définir des variables d'environnement
Définissez les variables d'environnement pour KAGGLE_USERNAME
et KAGGLE_KEY
. Lorsque le message "Accorder l'accès ?" s'affiche, messages, acceptez de fournir un accès secret.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
3. Installer la bibliothèque gemma
L'accélération matérielle sans frais de Colab est actuellement insuffisante pour exécuter ce notebook. Si vous utilisez le paiement à l'usage Colab ou Colab Pro, cliquez sur Modifier > Paramètres du notebook > Sélectionnez GPU A100 > Enregistrez pour activer l'accélération matérielle.
Vous devez ensuite installer la bibliothèque Google DeepMind gemma
à partir de github.com/google-deepmind/gemma
. Si vous obtenez une erreur concernant le "résolveur de dépendances de pip", vous pouvez généralement l'ignorer.
pip install -q git+https://github.com/google-deepmind/gemma.git
4. Importer des bibliothèques
Ce notebook utilise Gemma (qui utilise Flax pour créer les couches de son réseau de neurones) et SentencePiece (pour la tokenisation).
import os
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm
Charger le modèle CodeGemma
Chargez le modèle CodeGemma avec kagglehub.model_download
, qui utilise trois arguments:
handle
: le gestionnaire de modèle de Kagglepath
(chaîne facultative) : chemin d'accès localforce_download
: (valeur booléenne facultative) force le téléchargement à nouveau du modèle.
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub
GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7) Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download... 100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s] Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3
Vérifiez l'emplacement des pondérations du modèle et de la fonction de tokenisation, puis définissez les variables de chemin. Le répertoire de tokenisation se trouve dans le répertoire principal dans lequel vous avez téléchargé le modèle, tandis que les pondérations du modèle sont dans un sous-répertoire. Exemple :
- Le fichier de tokenisation
spm.model
se trouvera dans/LOCAL/PATH/TO/codegemma/flax/2b-pt/3
. - Le point de contrôle du modèle se trouvera dans
/LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model
Effectuer un échantillonnage/une inférence
Chargez et formatez le point de contrôle du modèle CodeGemma à l'aide de la méthode gemma.params.load_and_format_params
:
params = params_lib.load_and_format_params(CKPT_PATH)
Chargez la fonction de tokenisation CodeGemma, créée à l'aide de sentencepiece.SentencePieceProcessor
:
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
Pour charger automatiquement la configuration appropriée à partir du point de contrôle du modèle CodeGemma, utilisez gemma.transformer.TransformerConfig
. L'argument cache_size
correspond au nombre de pas de temps dans le cache Transformer
de CodeGemma. Instanciez ensuite le modèle CodeGemma en tant que model_2b
avec gemma.transformer.Transformer
(qui hérite de flax.linen.Module
).
transformer_config = transformer_lib.TransformerConfig.from_params(
params,
cache_size=1024
)
transformer = transformer_lib.Transformer(config=transformer_config)
Créez un sampler
avec gemma.sampler.Sampler
. Elle utilise le point de contrôle du modèle CodeGemma et la fonction de tokenisation.
sampler = sampler_lib.Sampler(
transformer=transformer,
vocab=vocab,
params=params['transformer']
)
Créez des variables pour représenter les jetons de remplissage au milieu (Fim) et créez des fonctions d'assistance pour mettre en forme la requête et la sortie générée.
Examinons par exemple le code suivant:
def function(string):
assert function('asdf') == 'fdsa'
Nous aimerions renseigner function
pour que l'assertion contienne True
. Dans ce cas, le préfixe serait:
"def function(string):\n"
Et le suffixe serait:
"assert function('asdf') == 'fdsa'"
Nous formatons ensuite ceci en requête de type PREFIX-SUFFIX-MIDDLE (la section du milieu à remplir se trouve toujours à la fin de la requête):
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"
def format_completion_prompt(before, after):
print(f"\nORIGINAL PROMPT:\n{before}{after}")
prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
return prompt
def format_generated_output(before, after, output):
print(f"\nGENERATED OUTPUT:\n{repr(output)}")
formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
return formatted_output
Créer une requête et effectuer des inférences Spécifiez le texte du préfixe before
et le texte du suffixe after
, puis générez la requête mise en forme à l'aide de la fonction d'assistance format_completion prompt
.
Vous pouvez modifier total_generation_steps
(le nombre d'étapes effectuées lors de la génération d'une réponse. Cet exemple utilise 100
pour préserver la mémoire de l'hôte).
before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: def function(string): assert function('asdf') == 'fdsa' FORMATTED PROMPT: "<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>" GENERATED OUTPUT: ' return string[::-1]\n\n<|file_separator|>' FILL-IN COMPLETION: def function(string): return string[::-1] assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n sys.exit(0)"""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import if __name__ == "__main__": sys.exit(0) FORMATTED PROMPT: '<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n sys.exit(0)<|fim_middle|>' GENERATED OUTPUT: 'sys\n<|file_separator|>' FILL-IN COMPLETION: import sys if __name__ == "__main__": sys.exit(0)
before = """import numpy as np
def reflect(matrix):
# horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import numpy as np def reflect(matrix): # horizontally reflect a matrix FORMATTED PROMPT: '<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>' GENERATED OUTPUT: ' return np.flip(matrix, axis=1)\n<|file_separator|>' FILL-IN COMPLETION: import numpy as np def reflect(matrix): # horizontally reflect a matrix return np.flip(matrix, axis=1)
En savoir plus
- Pour en savoir plus sur la bibliothèque
gemma
de Google DeepMind sur GitHub, qui contient les docstrings des modules que vous avez utilisés dans ce tutoriel, tels quegemma.params
,gemma.transformer
etgemma.sampler
. - Les bibliothèques suivantes possèdent leurs propres sites de documentation: core JAX, Flax et Orbax.
- Pour obtenir de la documentation sur la fonction de tokenisation et de détokenisation
sentencepiece
, consultez le dépôt GitHubsentencepiece
de Google. - Pour obtenir de la documentation sur
kagglehub
, accédez àREADME.md
dans le dépôt GitHubkagglehub
de Kaggle. - Découvrez comment utiliser des modèles Gemma avec Google Cloud Vertex AI.
- Si vous utilisez des TPU Google Cloud (versions 3-8 et ultérieures), veillez également à passer au dernier package
jax[tpu]
(!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
), redémarrez l'environnement d'exécution et vérifiez que les versionsjax
etjaxlib
correspondent (!pip list | grep jax
). Cela peut éviter l'erreurRuntimeError
qui peut se produire en raison de la non-concordance des versionsjaxlib
etjax
. Pour obtenir d'autres instructions d'installation concernant JAX, consultez la documentation JAX.