Ver en ai.google.dev | Ejecutar en Google Colab | Ver el código fuente en GitHub |
Presentamos CodeGemma, una colección de modelos de código abierto basados en los modelos Gemma de Google DeepMind (Gemma Team et al., 2024). CodeGemma es una familia de modelos abiertos ligeros y de última generación creados a partir de la misma investigación y tecnología que se usaron para crear los modelos de Gemini.
Continuando con los modelos previamente entrenados de Gemma, los modelos de CodeGemma se entrenan aún más con más de 500 a 1000 mil millones de tokens de principalmente código, usando las mismas arquitecturas que la familia de modelos Gemma. Como resultado, los modelos de CodeGemma logran un rendimiento de código de vanguardia y generación de datos, a la vez que se mantienen sólidas habilidades de comprensión y razonamiento a gran escala.
CodeGemma tiene 3 variantes:
- Un modelo previamente entrenado en código 7B
- Un modelo 7B de código ajustado a instrucciones
- Es un modelo 2B, entrenado específicamente para el relleno de código y la generación abierta.
En esta guía, aprenderás a usar el modelo CodeGemma con Flax para realizar una tarea de finalización de código.
Configuración
1. Configurar acceso a Kaggle para CodeGemma
Para completar este instructivo, primero debes seguir las instrucciones de configuración en Configuración de Gemma, que te muestran cómo hacer lo siguiente:
- Obtén acceso a CodeGemma en kaggle.com.
- Selecciona un entorno de ejecución de Colab con suficientes recursos (la GPU T4 tiene memoria insuficiente; usa TPU v2 en su lugar) para ejecutar el modelo de CodeGemma.
- Generar y configurar un nombre de usuario Kaggle y una clave de API.
Después de completar la configuración de Gemma, continúa con la siguiente sección, en la que establecerás variables de entorno para tu entorno de Colab.
2. Configure las variables de entorno
Configura variables de entorno para KAGGLE_USERNAME
y KAGGLE_KEY
. Cuando aparezca el mensaje “¿Quieres otorgar acceso?” mensajes, acepta proporcionar acceso al Secret.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
3. Instala la biblioteca gemma
Por el momento, la aceleración de hardware de Colab es insuficiente para ejecutar este notebook. Si usas Colab Pay As You Go o Colab Pro, haz clic en Editar > Configuración del notebook > Selecciona GPU A100 > Guarda para habilitar la aceleración de hardware.
A continuación, debes instalar la biblioteca gemma
de Google DeepMind desde github.com/google-deepmind/gemma
. Si recibes un error sobre el “agente de resolución de dependencias de pip”, por lo general, puedes ignorarlo.
pip install -q git+https://github.com/google-deepmind/gemma.git
4. Importa las bibliotecas
Este notebook usa Gemma (que usa Flax para compilar sus capas de red neuronal) y SentencePiece (para la asignación de token).
import os
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm
Carga el modelo de CodeGemma
Carga el modelo de CodeGemma con kagglehub.model_download
, que toma tres argumentos:
handle
: el controlador del modelo de Kagglepath
: (Cadena opcional) es la ruta de acceso local.force_download
: (booleano opcional) Obliga a volver a descargar el modelo.
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub
GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7) Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download... 100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s] Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3
Verifica la ubicación de los pesos del modelo y el tokenizador, luego establece las variables de la ruta de acceso. El directorio del tokenizador estará en el directorio principal donde descargaste el modelo, mientras que los pesos del modelo estarán en un subdirectorio. Por ejemplo:
- El archivo del tokenizador
spm.model
estará en/LOCAL/PATH/TO/codegemma/flax/2b-pt/3
- El punto de control del modelo estará en
/LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model
Realizar muestreo o inferencia
Carga el punto de control del modelo CodeGemma y dale formato con el método gemma.params.load_and_format_params
:
params = params_lib.load_and_format_params(CKPT_PATH)
Carga el tokenizador de CodeGemma, construido con sentencepiece.SentencePieceProcessor
:
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
Para cargar automáticamente la configuración correcta desde el punto de control del modelo de CodeGemma, usa gemma.transformer.TransformerConfig
. El argumento cache_size
es la cantidad de pasos de la caché Transformer
de CodeGemma. Luego, crea una instancia del modelo CodeGemma como model_2b
con gemma.transformer.Transformer
(que se hereda de flax.linen.Module
).
transformer_config = transformer_lib.TransformerConfig.from_params(
params,
cache_size=1024
)
transformer = transformer_lib.Transformer(config=transformer_config)
Crea un sampler
con gemma.sampler.Sampler
. Utiliza el punto de control del modelo CodeGemma y el tokenizador.
sampler = sampler_lib.Sampler(
transformer=transformer,
vocab=vocab,
params=params['transformer']
)
Crea algunas variables para representar los tokens de completar el medio (fim) y crea algunas funciones auxiliares para dar formato a la instrucción y al resultado generado.
Por ejemplo, veamos el siguiente código:
def function(string):
assert function('asdf') == 'fdsa'
Nos gustaría completar function
para que la aserción contenga True
. En este caso, el prefijo sería el siguiente:
"def function(string):\n"
Y el sufijo sería:
"assert function('asdf') == 'fdsa'"
Luego, formamos esto en una instrucción como PREFIX-SUFFIX-MIDDLE (la sección del medio que se debe completar siempre está al final de la instrucción):
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"
def format_completion_prompt(before, after):
print(f"\nORIGINAL PROMPT:\n{before}{after}")
prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
return prompt
def format_generated_output(before, after, output):
print(f"\nGENERATED OUTPUT:\n{repr(output)}")
formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
return formatted_output
Crea una instrucción y realiza inferencias. Especifica el texto del prefijo before
y el texto del sufijo after
, y genera la instrucción con formato usando la función auxiliar format_completion prompt
.
Puedes modificar total_generation_steps
(la cantidad de pasos que se realizan cuando se genera una respuesta; en este ejemplo, se usa 100
para preservar la memoria del host).
before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: def function(string): assert function('asdf') == 'fdsa' FORMATTED PROMPT: "<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>" GENERATED OUTPUT: ' return string[::-1]\n\n<|file_separator|>' FILL-IN COMPLETION: def function(string): return string[::-1] assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n sys.exit(0)"""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import if __name__ == "__main__": sys.exit(0) FORMATTED PROMPT: '<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n sys.exit(0)<|fim_middle|>' GENERATED OUTPUT: 'sys\n<|file_separator|>' FILL-IN COMPLETION: import sys if __name__ == "__main__": sys.exit(0)
before = """import numpy as np
def reflect(matrix):
# horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import numpy as np def reflect(matrix): # horizontally reflect a matrix FORMATTED PROMPT: '<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>' GENERATED OUTPUT: ' return np.flip(matrix, axis=1)\n<|file_separator|>' FILL-IN COMPLETION: import numpy as np def reflect(matrix): # horizontally reflect a matrix return np.flip(matrix, axis=1)
Más información
- Puedes obtener más información sobre la biblioteca
gemma
de Google DeepMind en GitHub, que contiene docstrings de los módulos que usaste en este instructivo, comogemma.params
,gemma.transformer
ygemma.sampler
. - Las siguientes bibliotecas tienen sus propios sitios de documentación: core JAX, Flax y Orbax.
- Para ver la documentación del tokenizador/detokenizador
sentencepiece
, consulta el repositorio de GitHubsentencepiece
de Google. - Para ver la documentación de
kagglehub
, consultaREADME.md
en el repositorio de GitHubkagglehub
de Kaggle. - Aprende a usar modelos de Gemma con Vertex AI de Google Cloud.
- Si usas Google Cloud TPU (v3-8 y versiones posteriores), asegúrate de actualizar también al paquete
jax[tpu]
más reciente (!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
), reiniciar el entorno de ejecución y verificar que las versionesjax
yjaxlib
coincidan (!pip list | grep jax
). Esto puede evitar que se produzca elRuntimeError
debido a que las versiones dejaxlib
yjax
no coinciden. Para conocer más instrucciones de instalación de JAX, consulta los documentos de JAX.