Ver em ai.google.dev | Executar no Google Colab | Veja o código-fonte no GitHub |
Apresentamos o CodeGemma, uma coleção de modelos de código aberto baseados nos modelos Gemma do Google DeepMind (Gemma Team e outros 2024). O CodeGemma é uma família de modelos abertos, leves e modernos, criados com a mesma pesquisa e tecnologia usada para criar os modelos do Gemini.
Continuando dos modelos pré-treinados do Gemma, os modelos do CodeGemma são treinados com mais de 500 a 1.000 bilhões de tokens primariamente de código, usando as mesmas arquiteturas da família de modelos Gemma. Como resultado, os modelos do CodeGemma alcançam um desempenho de código de última geração em ambos e geração de cargas de trabalho, mantendo uma sólida compreensão e raciocínio em escala.
O CodeGemma tem três variantes:
- Um modelo pré-treinado com códigos 7B
- Um modelo de código ajustado por instruções 7B
- Um modelo 2B, treinado especificamente para preenchimento de código e geração aberta.
Este guia explica como usar o modelo CodeGemma com o Flax para uma tarefa de preenchimento de código.
Configuração
1. Configure o acesso do Kaggle para o CodeGemma
Para concluir este tutorial, primeiro siga as instruções de configuração em Configuração do Gemma, que mostram como fazer o seguinte:
- Acesse o CodeGemma em kaggle.com.
- Selecione um ambiente de execução do Colab com recursos suficientes (a GPU T4 não tem memória suficiente. Use a TPU v2) para executar o modelo do CodeGemma.
- Gere e configure um nome de usuário do Kaggle e uma chave de API.
Depois de concluir a configuração do Gemma, vá para a próxima seção, em que você definirá variáveis de ambiente para o ambiente do Colab.
2. Defina as variáveis de ambiente
Defina as variáveis de ambiente para KAGGLE_USERNAME
e KAGGLE_KEY
. Quando a pergunta "Conceder acesso?" for exibida concordam em fornecer acesso a secrets.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
3. Instalar a biblioteca gemma
No momento, a aceleração de hardware sem custo financeiro do Colab é insuficiente para executar este notebook. Se você estiver usando o Colab Pay as You Go ou o Colab Pro, clique em Editar > Configurações do notebook > Selecione GPU A100 > Salve para ativar a aceleração de hardware.
Em seguida, você precisa instalar a biblioteca gemma
do Google DeepMind no github.com/google-deepmind/gemma
. Se você receber um erro sobre "resolvedor de dependências do pip", geralmente ele poderá ser ignorado.
pip install -q git+https://github.com/google-deepmind/gemma.git
4. Importar bibliotecas
Este notebook usa o Gemma (que usa o Flax para criar as camadas de rede neural) e o SentencePiece (para tokenização).
import os
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm
Carregar o modelo do CodeGemma
Carregue o modelo do CodeGemma com kagglehub.model_download
, que usa três argumentos:
handle
: o identificador do modelo do Kagglepath
: (string opcional) o caminho localforce_download
(booleano opcional): força o download do modelo novamente.
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub
GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7) Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download... 100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s] Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3
Verifique o local dos pesos do modelo e do tokenizador e, em seguida, defina as variáveis do caminho. O diretório tokenizador estará no diretório principal em que você fez o download do modelo, enquanto os pesos do modelo estarão em um subdiretório. Exemplo:
- O arquivo tokenizador
spm.model
estará em/LOCAL/PATH/TO/codegemma/flax/2b-pt/3
- O checkpoint do modelo estará em
/LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model
Realizar amostragem/inferência
Carregue e formate o checkpoint do modelo CodeGemma com o método gemma.params.load_and_format_params
:
params = params_lib.load_and_format_params(CKPT_PATH)
Carregue o tokenizador CodeGemma, construído usando sentencepiece.SentencePieceProcessor
:
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
Para carregar automaticamente a configuração correta do checkpoint do modelo CodeGemma, use gemma.transformer.TransformerConfig
. O argumento cache_size
é o número de etapas de tempo no cache Transformer
do CodeGemma. Em seguida, instancie o modelo CodeGemma como model_2b
com gemma.transformer.Transformer
(herdado de flax.linen.Module
).
transformer_config = transformer_lib.TransformerConfig.from_params(
params,
cache_size=1024
)
transformer = transformer_lib.Transformer(config=transformer_config)
Crie uma sampler
com gemma.sampler.Sampler
. Ele usa o checkpoint do modelo CodeGemma e o tokenizador.
sampler = sampler_lib.Sampler(
transformer=transformer,
vocab=vocab,
params=params['transformer']
)
Crie algumas variáveis para representar os tokens de preenchimento no meio (fim) e crie algumas funções auxiliares para formatar o comando e a saída gerada.
Por exemplo, vamos analisar o seguinte código:
def function(string):
assert function('asdf') == 'fdsa'
Queremos preencher function
para que a declaração mantenha True
. Nesse caso, o prefixo seria:
"def function(string):\n"
E o sufixo seria:
"assert function('asdf') == 'fdsa'"
Em seguida, formatamos isso em um comando como PREFIX-SUFFIX-MIDDLE (a seção do meio que precisa ser preenchida está sempre no final do comando):
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"
def format_completion_prompt(before, after):
print(f"\nORIGINAL PROMPT:\n{before}{after}")
prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
return prompt
def format_generated_output(before, after, output):
print(f"\nGENERATED OUTPUT:\n{repr(output)}")
formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
return formatted_output
Criar um comando e realizar inferências. Especifique o texto do prefixo before
e o texto do sufixo after
e gere o comando formatado usando a função auxiliar format_completion prompt
.
Você pode ajustar o total_generation_steps
, que é o número de etapas realizadas ao gerar uma resposta. Este exemplo usa 100
para preservar a memória do host.
before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: def function(string): assert function('asdf') == 'fdsa' FORMATTED PROMPT: "<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>" GENERATED OUTPUT: ' return string[::-1]\n\n<|file_separator|>' FILL-IN COMPLETION: def function(string): return string[::-1] assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n sys.exit(0)"""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import if __name__ == "__main__": sys.exit(0) FORMATTED PROMPT: '<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n sys.exit(0)<|fim_middle|>' GENERATED OUTPUT: 'sys\n<|file_separator|>' FILL-IN COMPLETION: import sys if __name__ == "__main__": sys.exit(0)
before = """import numpy as np
def reflect(matrix):
# horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import numpy as np def reflect(matrix): # horizontally reflect a matrix FORMATTED PROMPT: '<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>' GENERATED OUTPUT: ' return np.flip(matrix, axis=1)\n<|file_separator|>' FILL-IN COMPLETION: import numpy as np def reflect(matrix): # horizontally reflect a matrix return np.flip(matrix, axis=1)
Saiba mais
- Saiba mais sobre a biblioteca
gemma
do Google DeepMind no GitHub (link em inglês), que contém docstrings dos módulos usados neste tutorial, comogemma.params
(links em inglês).gemma.transformer
egemma.sampler
- As bibliotecas a seguir têm seus próprios sites de documentação: core JAX, Flax e Orbax.
- Para consultar a documentação do tokenizador/destokenizador do
sentencepiece
, consulte o repositório do GitHub dosentencepiece
do Google (em inglês). - Para conferir a documentação de
kagglehub
, confiraREADME.md
no repositório do GitHub dokagglehub
do Kaggle (links em inglês). - Saiba como usar modelos do Gemma com a Vertex AI do Google Cloud.
- Se você estiver usando TPUs do Google Cloud (v3-8 e mais recentes), atualize também para o pacote
jax[tpu]
mais recente (!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
), reinicie o ambiente de execução e verifique se as versõesjax
ejaxlib
correspondem (!pip list | grep jax
). Isso pode evitar aRuntimeError
que pode surgir devido à incompatibilidade de versão dejaxlib
ejax
. Para mais instruções de instalação do JAX, consulte os documentos do JAX.