Visualizza su ai.google.dev | Esegui in Google Colab | Visualizza il codice sorgente su GitHub |
Presentiamo CodeGemma, una raccolta di modelli di codice aperto basati sui modelli Gemma di Google DeepMind (Gemma Team et al., 2024). CodeGemma è una famiglia di modelli aperti leggeri e all'avanguardia basati sulla stessa ricerca e tecnologia utilizzate per creare i modelli Gemini.
Partendo dai modelli preaddestrati di Gemma, i modelli CodeGemma vengono addestrati ulteriormente su oltre 500-1000 miliardi di token di codice principalmente utilizzando le stesse architetture della famiglia di modelli Gemma. Il risultato è che i modelli CodeGemma ottengono prestazioni all'avanguardia nel codice sia di sviluppo e generazione, mantenendo al contempo solide di comprensione e ragionamento su vasta scala.
CodeGemma ha 3 varianti:
- Un modello preaddestrato con codice 7B
- Un modello di codice ottimizzato per l'istruzione di 7 miliardi
- Un modello 2B, addestrato specificamente per il riempimento del codice e la generazione aperta.
Questa guida illustra l'utilizzo del modello CodeGemma con Flax per un'attività di completamento del codice.
Configurazione
1. Configura l'accesso a Kaggle per CodeGemma
Per completare questo tutorial, devi prima seguire le istruzioni di configurazione nella configurazione di Gemma, che mostrano come fare:
- Accedi a CodeGemma su kaggle.com.
- Seleziona un runtime Colab con risorse sufficienti (GPU T4 ha memoria insufficiente, usa invece TPU v2) per eseguire il modello CodeGemma.
- Genera e configura un nome utente e una chiave API Kaggle.
Dopo aver completato la configurazione di Gemma, passa alla sezione successiva, in cui imposterai le variabili di ambiente per il tuo ambiente Colab.
2. Imposta le variabili di ambiente
Imposta le variabili di ambiente per KAGGLE_USERNAME
e KAGGLE_KEY
. Quando viene visualizzata la richiesta "Vuoi concedere l'accesso?" messaggi, accetti di fornire l'accesso al secret.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
3. Installa la libreria gemma
Al momento l'accelerazione hardware Colab senza costi è insufficiente per eseguire questo blocco note. Se utilizzi Colab Pay As You Go o Colab Pro, fai clic su Modifica > Impostazioni blocco note > Seleziona GPU A100 > Salva per attivare l'accelerazione hardware.
Successivamente, devi installare la libreria Google DeepMind gemma
da github.com/google-deepmind/gemma
. Se ricevi un errore relativo al " resolver di dipendenze di pip", in genere puoi ignorarlo.
pip install -q git+https://github.com/google-deepmind/gemma.git
4. Importa librerie
Questo blocco note utilizza Gemma (che usa Flax per creare i livelli della rete neurale) e SentencePiece (per la tokenizzazione).
import os
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm
Carica il modello CodeGemma
Carica il modello CodeGemma con kagglehub.model_download
, che accetta tre argomenti:
handle
: l'handle del modello di Kagglepath
: (stringa facoltativa) il percorso localeforce_download
: (booleano facoltativo) forza a scaricare di nuovo il modello
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub
GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7) Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download... 100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s] Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3
Controlla la posizione dei pesi del modello e del tokenizzatore, quindi imposta le variabili di percorso. La directory del tokenizzatore si troverà nella directory principale in cui hai scaricato il modello, mentre i pesi del modello saranno in una sottodirectory. Ad esempio:
- Il file tokenizzatore
spm.model
sarà in/LOCAL/PATH/TO/codegemma/flax/2b-pt/3
- Il checkpoint del modello sarà in
/LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model
Eseguire campionamento/inferenza
Carica e formatta il checkpoint del modello CodeGemma con il metodo gemma.params.load_and_format_params
:
params = params_lib.load_and_format_params(CKPT_PATH)
Carica il tokenizzatore CodeGemma, creato con sentencepiece.SentencePieceProcessor
:
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
Per caricare automaticamente la configurazione corretta dal checkpoint del modello CodeGemma, utilizza gemma.transformer.TransformerConfig
. L'argomento cache_size
è il numero di passi temporali nella cache Transformer
di CodeGemma. In seguito, crea un'istanza del modello CodeGemma come model_2b
con gemma.transformer.Transformer
(che eredita da flax.linen.Module
).
transformer_config = transformer_lib.TransformerConfig.from_params(
params,
cache_size=1024
)
transformer = transformer_lib.Transformer(config=transformer_config)
Crea un sampler
con gemma.sampler.Sampler
. Utilizza il checkpoint del modello CodeGemma e il tokenizzatore.
sampler = sampler_lib.Sampler(
transformer=transformer,
vocab=vocab,
params=params['transformer']
)
Crea alcune variabili per rappresentare i token di riempimento (fim) e crea alcune funzioni helper per formattare il prompt e l'output generato.
Esaminiamo ad esempio il seguente codice:
def function(string):
assert function('asdf') == 'fdsa'
Vogliamo compilare il function
in modo che l'asserzione contenga True
. In questo caso, il prefisso sarebbe:
"def function(string):\n"
Il suffisso sarebbe:
"assert function('asdf') == 'fdsa'"
Quindi lo formattiamo in un prompt come proxy-SUFFIX-MIDDLE (la sezione centrale da compilare è sempre alla fine del prompt):
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"
def format_completion_prompt(before, after):
print(f"\nORIGINAL PROMPT:\n{before}{after}")
prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
return prompt
def format_generated_output(before, after, output):
print(f"\nGENERATED OUTPUT:\n{repr(output)}")
formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
return formatted_output
Crea un prompt ed esegui l'inferenza. Specifica il testo del prefisso before
e il testo del suffisso after
e genera il prompt formattato utilizzando la funzione helper format_completion prompt
.
Puoi modificare total_generation_steps
(il numero di passaggi eseguiti durante la generazione di una risposta; questo esempio utilizza 100
per preservare la memoria dell'host).
before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: def function(string): assert function('asdf') == 'fdsa' FORMATTED PROMPT: "<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>" GENERATED OUTPUT: ' return string[::-1]\n\n<|file_separator|>' FILL-IN COMPLETION: def function(string): return string[::-1] assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n sys.exit(0)"""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import if __name__ == "__main__": sys.exit(0) FORMATTED PROMPT: '<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n sys.exit(0)<|fim_middle|>' GENERATED OUTPUT: 'sys\n<|file_separator|>' FILL-IN COMPLETION: import sys if __name__ == "__main__": sys.exit(0)
before = """import numpy as np
def reflect(matrix):
# horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import numpy as np def reflect(matrix): # horizontally reflect a matrix FORMATTED PROMPT: '<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>' GENERATED OUTPUT: ' return np.flip(matrix, axis=1)\n<|file_separator|>' FILL-IN COMPLETION: import numpy as np def reflect(matrix): # horizontally reflect a matrix return np.flip(matrix, axis=1)
Scopri di più
- Puoi scoprire di più sulla libreria
gemma
di Google DeepMind su GitHub, che contiene le stringhe di documenti dei moduli che hai utilizzato in questo tutorial, ad esempiogemma.params
,gemma.transformer
egemma.sampler
. - Le seguenti librerie dispongono di siti di documentazione proprietari: JAX di base, Flax e Orbax.
- Per la documentazione relativa al tokenizzatore/detokenizzatore
sentencepiece
, consulta il repository GitHubsentencepiece
di Google. - Per la documentazione relativa a
kagglehub
, dai un'occhiata aREADME.md
nel repository GitHubkagglehub
di Kaggle. - Scopri come utilizzare i modelli Gemma con Vertex AI di Google Cloud.
- Se utilizzi Google Cloud TPU (v3-8 e successive), assicurati di eseguire anche l'aggiornamento al pacchetto
jax[tpu]
più recente (!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
), riavvia il runtime e controlla che le versionijax
ejaxlib
corrispondano (!pip list | grep jax
). In questo modo è possibile evitare i casi in cuiRuntimeError
potrebbe verificarsi a causa della mancata corrispondenza tra le versioni dijaxlib
ejax
. Per ulteriori istruzioni sull'installazione di JAX, consulta la documentazione JAX.