Inferenz mit CodeGemma unter Verwendung von JAX und Flax

Auf ai.google.dev ansehen In Google Colab ausführen Quelle auf GitHub ansehen

Wir präsentieren CodeGemma, eine Sammlung offener Codemodelle, die auf den Gemma-Modellen von Google DeepMind basiert (Gemma Team et al., 2024). CodeGemma ist eine Familie von leichten, hochmodernen offenen Modellen, die auf derselben Forschung und Technologie basieren, die auch zur Erstellung der Gemini-Modelle verwendet wurden.

Im Anschluss an vortrainierte Gemma-Modelle werden CodeGemma-Modelle weiter mit mehr als 500 bis 1.000 Milliarden Tokens trainiert. Dabei werden Architekturen wie in der Gemma-Modellfamilie. Infolgedessen erzielen CodeGemma-Modelle bei der Vervollständigung und Generieren von Aufgaben, während Sie gleichzeitig Verständnis- und Logikfähigkeiten in großem Maßstab.

CodeGemma hat drei Varianten:

  • Ein vortrainiertes Modell mit 7 Mrd. Code
  • Ein auf 7 Milliarden abgestimmtes Codemodell
  • Ein 2-Milliarden-Modell, das speziell für die Code-Füllung und offene Generierung trainiert wurde.

Dieser Leitfaden führt Sie durch die Verwendung des CodeGemma-Modells mit Flax für eine Codevervollständigungsaufgabe.

Einrichtung

1. Kaggle-Zugriff für CodeGemma einrichten

Um diese Anleitung abzuschließen, müssen Sie zuerst der Anleitung unter Gemma-Einrichtung folgen. Sie erfahren, wie Sie Folgendes tun:

  • Erhalte über kaggle.com Zugriff auf CodeGemma.
  • Wählen Sie eine Colab-Laufzeit mit ausreichenden Ressourcen aus (T4-GPU hat unzureichenden Arbeitsspeicher, verwenden Sie stattdessen TPU v2), um das CodeGemma-Modell auszuführen.
  • Generieren und konfigurieren Sie einen Kaggle-Nutzernamen und einen API-Schlüssel.

Nachdem Sie die Gemma-Einrichtung abgeschlossen haben, fahren Sie mit dem nächsten Abschnitt fort. Dort legen Sie Umgebungsvariablen für Ihre Colab-Umgebung fest.

2. Umgebungsvariablen festlegen

Legen Sie Umgebungsvariablen für KAGGLE_USERNAME und KAGGLE_KEY fest. Wenn die Aufforderung „Zugriff erlauben?“ angezeigt wird, -Nachrichten, erklären Sie sich damit einverstanden, Secret-Zugriff bereitzustellen.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. gemma-Bibliothek installieren

Die kostenlose Colab-Hardwarebeschleunigung ist derzeit nicht ausreichend, um dieses Notebook auszuführen. Wenn Sie Colab Pay As You Go oder Colab Pro verwenden, klicken Sie auf Bearbeiten > Notebook-Einstellungen > Wählen Sie A100 GPU aus > Speichern, um die Hardwarebeschleunigung zu aktivieren.

Als Nächstes müssen Sie die gemma-Bibliothek von Google DeepMind von github.com/google-deepmind/gemma installieren. Wenn Sie einen Fehler zum „Abhängigkeitsauflöser von pip“ erhalten, können Sie ihn in der Regel ignorieren.

pip install -q git+https://github.com/google-deepmind/gemma.git

4. Bibliotheken importieren

Dieses Notebook verwendet Gemma (das Flax zum Erstellen der neuronalen Netzwerkschichten verwendet) und SentencePiece (für die Tokenisierung).

import os
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm

CodeGemma-Modell laden

Laden Sie das CodeGemma-Modell mit kagglehub.model_download. Dafür werden drei Argumente benötigt:

  • handle: Das Modell-Handle von Kaggle
  • path: (optionaler String) der lokale Pfad
  • force_download: (optionaler boolescher Wert) Erzwingt das erneute Herunterladen des Modells
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7)
Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download...
100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3

Überprüfen Sie den Speicherort der Modellgewichtungen und des Tokenizers und legen Sie dann die Pfadvariablen fest. Das Tokenizer-Verzeichnis befindet sich im Hauptverzeichnis, in das Sie das Modell heruntergeladen haben, und die Modellgewichtungen befinden sich in einem Unterverzeichnis. Beispiel:

  • Die Tokenizer-Datei spm.model befindet sich in /LOCAL/PATH/TO/codegemma/flax/2b-pt/3
  • Der Modellprüfpunkt befindet sich in /LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model

Stichprobenerhebung/Inferenz durchführen

Laden und formatieren Sie den CodeGemma-Modellprüfpunkt mit der Methode gemma.params.load_and_format_params:

params = params_lib.load_and_format_params(CKPT_PATH)

Laden Sie den CodeGemma-Tokenizer, der mit sentencepiece.SentencePieceProcessor erstellt wurde:

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True

Verwenden Sie gemma.transformer.TransformerConfig, um automatisch die richtige Konfiguration aus dem CodeGemma-Modellprüfpunkt zu laden. Das Argument cache_size ist die Anzahl der zeitlichen Schritte im Transformer-Cache von CodeGemma. Instanziieren Sie anschließend das CodeGemma-Modell als model_2b mit gemma.transformer.Transformer (das von flax.linen.Module übernommen wird).

transformer_config = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(config=transformer_config)

Erstellen Sie eine sampler mit gemma.sampler.Sampler. Dabei werden der CodeGemma-Modellprüfpunkt und der Tokenizer verwendet.

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer']
)

Erstellen Sie einige Variablen, die die Fill-in-the-Middle-Tokens (fim) darstellen, und erstellen Sie einige Hilfsfunktionen, um den Prompt und die generierte Ausgabe zu formatieren.

Sehen wir uns zum Beispiel den folgenden Code an:

def function(string):
assert function('asdf') == 'fdsa'

Wir möchten die function ausfüllen, damit die Assertion True enthält. In diesem Fall lautet das Präfix:

"def function(string):\n"

Das Suffix wäre:

"assert function('asdf') == 'fdsa'"

Anschließend formatieren wir dies in einer Eingabeaufforderung als PREFIX-SUFFIX-MIDDLE (der zu füllende mittlere Abschnitt steht immer am Ende der Aufforderung):

"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"

def format_completion_prompt(before, after):
  print(f"\nORIGINAL PROMPT:\n{before}{after}")
  prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
  print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
  return prompt
def format_generated_output(before, after, output):
  print(f"\nGENERATED OUTPUT:\n{repr(output)}")
  formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
  print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
  return formatted_output

Erstellen Sie einen Prompt und führen Sie eine Inferenz durch. Geben Sie den Präfix-Text before und das Suffix after an und generieren Sie den formatierten Prompt mit der Hilfsfunktion format_completion prompt.

Sie können total_generation_steps optimieren (die Anzahl der Schritte, die beim Generieren einer Antwort ausgeführt werden; in diesem Beispiel wird 100 verwendet, um den Hostspeicher beizubehalten).

before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
def function(string):
assert function('asdf') == 'fdsa'

FORMATTED PROMPT:
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"

GENERATED OUTPUT:
'    return string[::-1]\n\n<|file_separator|>'

FILL-IN COMPLETION:
def function(string):
    return string[::-1]

assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n    sys.exit(0)"""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import if __name__ == "__main__":
    sys.exit(0)

FORMATTED PROMPT:
'<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n    sys.exit(0)<|fim_middle|>'

GENERATED OUTPUT:
'sys\n<|file_separator|>'

FILL-IN COMPLETION:
import sys
if __name__ == "__main__":
    sys.exit(0)
before = """import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix


FORMATTED PROMPT:
'<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n  # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>'

GENERATED OUTPUT:
'  return np.flip(matrix, axis=1)\n<|file_separator|>'

FILL-IN COMPLETION:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
  return np.flip(matrix, axis=1)

Weitere Informationen