Вывод с помощью CodeGemma с использованием JAX и Flax

Посмотреть на ai.google.dev Запустить в Google Colab Посмотреть исходный код на GitHub

Мы представляем CodeGemma, коллекцию моделей открытого кода, основанную на моделях Gemma компании Google DeepMind (Gemma Team et al., 2024). CodeGemma — это семейство легких современных открытых моделей, созданных на основе тех же исследований и технологий, которые использовались при создании моделей Gemini.

Продолжая предварительно обученные модели Gemma, модели CodeGemma дополнительно обучаются на более чем 500–1000 миллиардах токенов, состоящих в основном из кода, с использованием той же архитектуры, что и семейство моделей Gemma. В результате модели CodeGemma обеспечивают современную производительность кода как при выполнении задач завершения, так и при генерации, сохраняя при этом глубокие навыки понимания и рассуждения в масштабе.

CodeGemma имеет 3 варианта:

  • Предварительно обученная модель кода 7B
  • Модель кода 7B, настроенная с помощью инструкций
  • Модель 2B, специально обученная для заполнения кода и открытой генерации.

В этом руководстве рассказывается, как использовать модель CodeGemma с Flax для выполнения задачи завершения кода.

Настраивать

1. Настройте доступ Kaggle для CodeGemma.

Чтобы выполнить это руководство, сначала необходимо следовать инструкциям по настройке на странице Gemma setup , которые показывают, как сделать следующее:

  • Получите доступ к CodeGemma на kaggle.com .
  • Выберите среду выполнения Colab с достаточными ресурсами ( графическому процессору T4 недостаточно памяти, вместо этого используйте TPU v2 ) для запуска модели CodeGemma.
  • Создайте и настройте имя пользователя Kaggle и ключ API.

После завершения настройки Gemma перейдите к следующему разделу, где вы установите переменные среды для вашей среды Colab.

2. Установите переменные среды

Установите переменные среды для KAGGLE_USERNAME и KAGGLE_KEY . При появлении запроса «Предоставить доступ?» сообщения, согласитесь предоставить секретный доступ.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. Установите библиотеку gemma

Бесплатного аппаратного ускорения Colab в настоящее время недостаточно для работы этого ноутбука. Если вы используете Colab Pay As You Go или Colab Pro , нажмите «Редактировать» > «Настройки ноутбука» > «Выбрать графический процессор A100» > «Сохранить» , чтобы включить аппаратное ускорение.

Далее вам необходимо установить библиотеку gemma Google DeepMind с github.com/google-deepmind/gemma . Если вы получаете сообщение об ошибке «преобразователь зависимостей pip», обычно вы можете игнорировать его.

pip install -q git+https://github.com/google-deepmind/gemma.git

4. Импортируйте библиотеки

В этом блокноте используется Gemma (которая использует Flax для построения слоев нейронной сети) и SentencePiece (для токенизации).

import os
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm

Загрузите модель CodeGemma

Загрузите модель CodeGemma с помощью kagglehub.model_download , который принимает три аргумента:

  • handle : ручка модели от Kaggle.
  • path : (Необязательная строка) Локальный путь
  • force_download : (Необязательное логическое значение) Принудительно повторно загрузить модель.
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7)
Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download...
100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3

Проверьте расположение весов модели и токенизатора, затем установите переменные пути. Каталог токенизатора будет находиться в основном каталоге, в который вы загрузили модель, а веса модели будут находиться в подкаталоге. Например:

  • Файл токенизатора spm.model будет находиться в /LOCAL/PATH/TO/codegemma/flax/2b-pt/3
  • Контрольная точка модели будет находиться в /LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model

Выполнить выборку/вывод

Загрузите и отформатируйте контрольную точку модели CodeGemma с помощью метода gemma.params.load_and_format_params :

params = params_lib.load_and_format_params(CKPT_PATH)

Загрузите токенизатор CodeGemma, созданный с помощью sentencepiece.SentencePieceProcessor :

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True

Чтобы автоматически загрузить правильную конфигурацию из контрольной точки модели CodeGemma, используйте gemma.transformer.TransformerConfig . Аргумент cache_size — это количество временных шагов в кеше CodeGemma Transformer . После этого создайте экземпляр модели CodeGemma как model_2b с помощью gemma.transformer.Transformer (который наследуется от flax.linen.Module ).

transformer_config = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(config=transformer_config)

Создайте sampler с помощью gemma.sampler.Sampler . Он использует контрольную точку модели CodeGemma и токенизатор.

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer']
)

Создайте несколько переменных для представления токенов заполнения посередине (fim) и создайте несколько вспомогательных функций для форматирования приглашения и сгенерированного вывода.

Например, давайте посмотрим на следующий код:

def function(string):
assert function('asdf') == 'fdsa'

Мы хотели бы заполнить function так, чтобы утверждение выполнялось True . В этом случае префикс будет таким:

"def function(string):\n"

И суффикс будет:

"assert function('asdf') == 'fdsa'"

Затем мы форматируем это в приглашение как PREFIX-SUFFIX-MIDDLE (средний раздел, который необходимо заполнить, всегда находится в конце приглашения):

"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"

def format_completion_prompt(before, after):
  print(f"\nORIGINAL PROMPT:\n{before}{after}")
  prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
  print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
  return prompt
def format_generated_output(before, after, output):
  print(f"\nGENERATED OUTPUT:\n{repr(output)}")
  formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
  print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
  return formatted_output

Создайте подсказку и выполните вывод. Укажите префикс before текстом и суффикс after текста и сгенерируйте форматированное приглашение с помощью вспомогательной функции format_completion prompt .

Вы можете настроить total_generation_steps (количество шагов, выполняемых при генерации ответа — в этом примере используется 100 для экономии памяти хоста).

before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
def function(string):
assert function('asdf') == 'fdsa'

FORMATTED PROMPT:
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"

GENERATED OUTPUT:
'    return string[::-1]\n\n<|file_separator|>'

FILL-IN COMPLETION:
def function(string):
    return string[::-1]

assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n    sys.exit(0)"""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import if __name__ == "__main__":
    sys.exit(0)

FORMATTED PROMPT:
'<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n    sys.exit(0)<|fim_middle|>'

GENERATED OUTPUT:
'sys\n<|file_separator|>'

FILL-IN COMPLETION:
import sys
if __name__ == "__main__":
    sys.exit(0)
before = """import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix


FORMATTED PROMPT:
'<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n  # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>'

GENERATED OUTPUT:
'  return np.flip(matrix, axis=1)\n<|file_separator|>'

FILL-IN COMPLETION:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
  return np.flip(matrix, axis=1)

Узнать больше

  • Вы можете узнать больше о библиотеке gemma Google DeepMind на GitHub , которая содержит строки документации модулей, которые вы использовали в этом руководстве, таких как gemma.params , gemma.transformer и gemma.sampler .
  • Следующие библиотеки имеют собственные сайты документации: core JAX , Flax и Orbax .
  • Документацию по токенизатору/детокенизатору sentencepiece можно найти в репозитории Google sentencepiece на GitHub .
  • Документацию kagglehub можно найти README.md в репозитории kagglehub на GitHub .
  • Узнайте, как использовать модели Gemma с Google Cloud Vertex AI .
  • Если вы используете Google Cloud TPU (v3-8 и новее), обязательно также обновите последнюю версию пакета jax[tpu] ( !pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html ), перезапустите среду выполнения и проверьте соответствие версий jax и jaxlib ( !pip list | grep jax ). Это может предотвратить RuntimeError , которая может возникнуть из-за несоответствия версий jaxlib и jax . Дополнительные инструкции по установке JAX см. в документации JAX .