Inferência com o CodeGemma usando JAX e Flax

Ver em ai.google.dev Executar no Google Colab Veja o código-fonte no GitHub

Apresentamos o CodeGemma, uma coleção de modelos de código aberto baseados nos modelos Gemma do Google DeepMind (Gemma Team e outros 2024). O CodeGemma é uma família de modelos abertos, leves e modernos, criados com a mesma pesquisa e tecnologia usada para criar os modelos do Gemini.

Continuando dos modelos pré-treinados do Gemma, os modelos do CodeGemma são treinados com mais de 500 a 1.000 bilhões de tokens primariamente de código, usando as mesmas arquiteturas da família de modelos Gemma. Como resultado, os modelos do CodeGemma alcançam um desempenho de código de última geração em ambos e geração de cargas de trabalho, mantendo uma sólida compreensão e raciocínio em escala.

O CodeGemma tem três variantes:

  • Um modelo pré-treinado com códigos 7B
  • Um modelo de código ajustado por instruções 7B
  • Um modelo 2B, treinado especificamente para preenchimento de código e geração aberta.

Este guia explica como usar o modelo CodeGemma com o Flax para uma tarefa de preenchimento de código.

Configuração

1. Configure o acesso do Kaggle para o CodeGemma

Para concluir este tutorial, primeiro siga as instruções de configuração em Configuração do Gemma, que mostram como fazer o seguinte:

  • Acesse o CodeGemma em kaggle.com.
  • Selecione um ambiente de execução do Colab com recursos suficientes (a GPU T4 não tem memória suficiente. Use a TPU v2) para executar o modelo do CodeGemma.
  • Gere e configure um nome de usuário do Kaggle e uma chave de API.

Depois de concluir a configuração do Gemma, vá para a próxima seção, em que você definirá variáveis de ambiente para o ambiente do Colab.

2. Defina as variáveis de ambiente

Defina as variáveis de ambiente para KAGGLE_USERNAME e KAGGLE_KEY. Quando a pergunta "Conceder acesso?" for exibida concordam em fornecer acesso a secrets.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. Instalar a biblioteca gemma

No momento, a aceleração de hardware sem custo financeiro do Colab é insuficiente para executar este notebook. Se você estiver usando o Colab Pay as You Go ou o Colab Pro, clique em Editar > Configurações do notebook > Selecione GPU A100 > Salve para ativar a aceleração de hardware.

Em seguida, você precisa instalar a biblioteca gemma do Google DeepMind no github.com/google-deepmind/gemma. Se você receber um erro sobre "resolvedor de dependências do pip", geralmente ele poderá ser ignorado.

pip install -q git+https://github.com/google-deepmind/gemma.git

4. Importar bibliotecas

Este notebook usa o Gemma (que usa o Flax para criar as camadas de rede neural) e o SentencePiece (para tokenização).

import os
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm

Carregar o modelo do CodeGemma

Carregue o modelo do CodeGemma com kagglehub.model_download, que usa três argumentos:

  • handle: o identificador do modelo do Kaggle
  • path: (string opcional) o caminho local
  • force_download (booleano opcional): força o download do modelo novamente.
.
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7)
Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download...
100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3

Verifique o local dos pesos do modelo e do tokenizador e, em seguida, defina as variáveis do caminho. O diretório tokenizador estará no diretório principal em que você fez o download do modelo, enquanto os pesos do modelo estarão em um subdiretório. Exemplo:

  • O arquivo tokenizador spm.model estará em /LOCAL/PATH/TO/codegemma/flax/2b-pt/3
  • O checkpoint do modelo estará em /LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model

Realizar amostragem/inferência

Carregue e formate o checkpoint do modelo CodeGemma com o método gemma.params.load_and_format_params:

params = params_lib.load_and_format_params(CKPT_PATH)

Carregue o tokenizador CodeGemma, construído usando sentencepiece.SentencePieceProcessor:

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True

Para carregar automaticamente a configuração correta do checkpoint do modelo CodeGemma, use gemma.transformer.TransformerConfig. O argumento cache_size é o número de etapas de tempo no cache Transformer do CodeGemma. Em seguida, instancie o modelo CodeGemma como model_2b com gemma.transformer.Transformer (herdado de flax.linen.Module).

transformer_config = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(config=transformer_config)

Crie uma sampler com gemma.sampler.Sampler. Ele usa o checkpoint do modelo CodeGemma e o tokenizador.

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer']
)

Crie algumas variáveis para representar os tokens de preenchimento no meio (fim) e crie algumas funções auxiliares para formatar o comando e a saída gerada.

Por exemplo, vamos analisar o seguinte código:

def function(string):
assert function('asdf') == 'fdsa'

Queremos preencher function para que a declaração mantenha True. Nesse caso, o prefixo seria:

"def function(string):\n"

E o sufixo seria:

"assert function('asdf') == 'fdsa'"

Em seguida, formatamos isso em um comando como PREFIX-SUFFIX-MIDDLE (a seção do meio que precisa ser preenchida está sempre no final do comando):

"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"

def format_completion_prompt(before, after):
  print(f"\nORIGINAL PROMPT:\n{before}{after}")
  prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
  print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
  return prompt
def format_generated_output(before, after, output):
  print(f"\nGENERATED OUTPUT:\n{repr(output)}")
  formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
  print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
  return formatted_output

Criar um comando e realizar inferências. Especifique o texto do prefixo before e o texto do sufixo after e gere o comando formatado usando a função auxiliar format_completion prompt.

Você pode ajustar o total_generation_steps, que é o número de etapas realizadas ao gerar uma resposta. Este exemplo usa 100 para preservar a memória do host.

before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
def function(string):
assert function('asdf') == 'fdsa'

FORMATTED PROMPT:
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"

GENERATED OUTPUT:
'    return string[::-1]\n\n<|file_separator|>'

FILL-IN COMPLETION:
def function(string):
    return string[::-1]

assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n    sys.exit(0)"""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import if __name__ == "__main__":
    sys.exit(0)

FORMATTED PROMPT:
'<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n    sys.exit(0)<|fim_middle|>'

GENERATED OUTPUT:
'sys\n<|file_separator|>'

FILL-IN COMPLETION:
import sys
if __name__ == "__main__":
    sys.exit(0)
before = """import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix


FORMATTED PROMPT:
'<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n  # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>'

GENERATED OUTPUT:
'  return np.flip(matrix, axis=1)\n<|file_separator|>'

FILL-IN COMPLETION:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
  return np.flip(matrix, axis=1)

Saiba mais