Inferenza con CodeGemma tramite JAX e Flax

Visualizza su ai.google.dev Esegui in Google Colab Visualizza il codice sorgente su GitHub

Presentiamo CodeGemma, una raccolta di modelli di codice aperto basati sui modelli Gemma di Google DeepMind (Gemma Team et al., 2024). CodeGemma è una famiglia di modelli aperti leggeri e all'avanguardia basati sulla stessa ricerca e tecnologia utilizzate per creare i modelli Gemini.

Partendo dai modelli preaddestrati di Gemma, i modelli CodeGemma vengono addestrati ulteriormente su oltre 500-1000 miliardi di token di codice principalmente utilizzando le stesse architetture della famiglia di modelli Gemma. Il risultato è che i modelli CodeGemma ottengono prestazioni all'avanguardia nel codice sia di sviluppo e generazione, mantenendo al contempo solide di comprensione e ragionamento su vasta scala.

CodeGemma ha 3 varianti:

  • Un modello preaddestrato con codice 7B
  • Un modello di codice ottimizzato per l'istruzione di 7 miliardi
  • Un modello 2B, addestrato specificamente per il riempimento del codice e la generazione aperta.

Questa guida illustra l'utilizzo del modello CodeGemma con Flax per un'attività di completamento del codice.

Configurazione

1. Configura l'accesso a Kaggle per CodeGemma

Per completare questo tutorial, devi prima seguire le istruzioni di configurazione nella configurazione di Gemma, che mostrano come fare:

  • Accedi a CodeGemma su kaggle.com.
  • Seleziona un runtime Colab con risorse sufficienti (GPU T4 ha memoria insufficiente, usa invece TPU v2) per eseguire il modello CodeGemma.
  • Genera e configura un nome utente e una chiave API Kaggle.

Dopo aver completato la configurazione di Gemma, passa alla sezione successiva, in cui imposterai le variabili di ambiente per il tuo ambiente Colab.

2. Imposta le variabili di ambiente

Imposta le variabili di ambiente per KAGGLE_USERNAME e KAGGLE_KEY. Quando viene visualizzata la richiesta "Vuoi concedere l'accesso?" messaggi, accetti di fornire l'accesso al secret.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. Installa la libreria gemma

Al momento l'accelerazione hardware Colab senza costi è insufficiente per eseguire questo blocco note. Se utilizzi Colab Pay As You Go o Colab Pro, fai clic su Modifica > Impostazioni blocco note > Seleziona GPU A100 > Salva per attivare l'accelerazione hardware.

Successivamente, devi installare la libreria Google DeepMind gemma da github.com/google-deepmind/gemma. Se ricevi un errore relativo al " resolver di dipendenze di pip", in genere puoi ignorarlo.

pip install -q git+https://github.com/google-deepmind/gemma.git

4. Importa librerie

Questo blocco note utilizza Gemma (che usa Flax per creare i livelli della rete neurale) e SentencePiece (per la tokenizzazione).

import os
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm

Carica il modello CodeGemma

Carica il modello CodeGemma con kagglehub.model_download, che accetta tre argomenti:

  • handle: l'handle del modello di Kaggle
  • path: (stringa facoltativa) il percorso locale
  • force_download: (booleano facoltativo) forza a scaricare di nuovo il modello
di Gemini Advanced.
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7)
Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download...
100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3

Controlla la posizione dei pesi del modello e del tokenizzatore, quindi imposta le variabili di percorso. La directory del tokenizzatore si troverà nella directory principale in cui hai scaricato il modello, mentre i pesi del modello saranno in una sottodirectory. Ad esempio:

  • Il file tokenizzatore spm.model sarà in /LOCAL/PATH/TO/codegemma/flax/2b-pt/3
  • Il checkpoint del modello sarà in /LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model

Eseguire campionamento/inferenza

Carica e formatta il checkpoint del modello CodeGemma con il metodo gemma.params.load_and_format_params:

params = params_lib.load_and_format_params(CKPT_PATH)

Carica il tokenizzatore CodeGemma, creato con sentencepiece.SentencePieceProcessor:

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True

Per caricare automaticamente la configurazione corretta dal checkpoint del modello CodeGemma, utilizza gemma.transformer.TransformerConfig. L'argomento cache_size è il numero di passi temporali nella cache Transformer di CodeGemma. In seguito, crea un'istanza del modello CodeGemma come model_2b con gemma.transformer.Transformer (che eredita da flax.linen.Module).

transformer_config = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(config=transformer_config)

Crea un sampler con gemma.sampler.Sampler. Utilizza il checkpoint del modello CodeGemma e il tokenizzatore.

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer']
)

Crea alcune variabili per rappresentare i token di riempimento (fim) e crea alcune funzioni helper per formattare il prompt e l'output generato.

Esaminiamo ad esempio il seguente codice:

def function(string):
assert function('asdf') == 'fdsa'

Vogliamo compilare il function in modo che l'asserzione contenga True. In questo caso, il prefisso sarebbe:

"def function(string):\n"

Il suffisso sarebbe:

"assert function('asdf') == 'fdsa'"

Quindi lo formattiamo in un prompt come proxy-SUFFIX-MIDDLE (la sezione centrale da compilare è sempre alla fine del prompt):

"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"

def format_completion_prompt(before, after):
  print(f"\nORIGINAL PROMPT:\n{before}{after}")
  prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
  print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
  return prompt
def format_generated_output(before, after, output):
  print(f"\nGENERATED OUTPUT:\n{repr(output)}")
  formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
  print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
  return formatted_output

Crea un prompt ed esegui l'inferenza. Specifica il testo del prefisso before e il testo del suffisso after e genera il prompt formattato utilizzando la funzione helper format_completion prompt.

Puoi modificare total_generation_steps (il numero di passaggi eseguiti durante la generazione di una risposta; questo esempio utilizza 100 per preservare la memoria dell'host).

before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
def function(string):
assert function('asdf') == 'fdsa'

FORMATTED PROMPT:
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"

GENERATED OUTPUT:
'    return string[::-1]\n\n<|file_separator|>'

FILL-IN COMPLETION:
def function(string):
    return string[::-1]

assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n    sys.exit(0)"""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import if __name__ == "__main__":
    sys.exit(0)

FORMATTED PROMPT:
'<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n    sys.exit(0)<|fim_middle|>'

GENERATED OUTPUT:
'sys\n<|file_separator|>'

FILL-IN COMPLETION:
import sys
if __name__ == "__main__":
    sys.exit(0)
before = """import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix


FORMATTED PROMPT:
'<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n  # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>'

GENERATED OUTPUT:
'  return np.flip(matrix, axis=1)\n<|file_separator|>'

FILL-IN COMPLETION:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
  return np.flip(matrix, axis=1)

Scopri di più