JAX 및 Flax를 사용하여 CodeGemma로 추론

ai.google.dev에서 보기 Google Colab에서 실행 GitHub에서 소스 보기

Google DeepMind의 Gemma 모델 (Gemma Team 외, 2024년). CodeGemma는 Gemini 모델을 만드는 데 사용된 것과 동일한 연구와 기술을 바탕으로 빌드된 최첨단 경량 개방형 모델 제품군입니다.

Gemma 선행 학습된 모델에서 계속 진행하면 CodeGemma 모델은 주로 500~1,000억 개 이상의 토큰으로 구성된 Gemma 모델군과 동일한 아키텍처를 사용합니다 결과적으로 CodeGemma 모델은 완성된 두 코드 모두에서 최고의 코드 성능을 달성합니다. 생성 작업을 자동화하는 동시에 대규모 언어 모델을 학습시킬 수 있습니다

CodeGemma에는 세 가지 변형이 있습니다.

  • 7B 코드 선행 학습된 모델
  • 7B 명령 조정 코드 모델
  • 코드 입력 및 개방형 생성을 위해 특별히 학습된 2B 모델입니다.

이 가이드에서는 코드 완성 작업을 위해 Flax와 함께 CodeGemma 모델을 사용하는 방법을 설명합니다.

설정

1. CodeGemma에 Kaggle 액세스 권한 설정하기

이 튜토리얼을 완료하려면 먼저 Gemma 설정에서 다음을 수행하는 방법을 보여주는 설정 안내를 따라야 합니다.

  • kaggle.com에서 CodeGemma에 액세스하세요.
  • 리소스가 충분한 Colab 런타임 (T4 GPU 메모리가 부족합니다. 대신 TPU v2를 사용하세요)을 선택하여 CodeGemma 모델을 실행합니다.
  • Kaggle 사용자 이름 및 API 키를 생성하고 구성합니다.

Gemma 설정을 완료한 후 다음 섹션으로 이동하여 Colab 환경의 환경 변수를 설정합니다.

2. 환경 변수 설정하기

KAGGLE_USERNAMEKAGGLE_KEY의 환경 변수를 설정합니다. '액세스 권한을 부여하시겠습니까?'라는 메시지가 표시되면 비밀 액세스 제공에 동의해야 합니다.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. gemma 라이브러리 설치

현재 무료 Colab 하드웨어 가속으로는 이 노트북을 실행할 수 없습니다. Colab 종량제 또는 Colab Pro를 사용하는 경우 수정을 클릭합니다. 노트북 설정 > A100 GPU를 선택합니다. 저장하여 하드웨어 가속을 사용 설정합니다.

다음으로 github.com/google-deepmind/gemma에서 Google DeepMind gemma 라이브러리를 설치해야 합니다. 'pip의 종속 항목 리졸버'에 관한 오류가 발생하면 일반적으로 무시해도 됩니다.

pip install -q git+https://github.com/google-deepmind/gemma.git

4. 라이브러리 가져오기

이 노트북은 Gemma (Flax를 사용하여 신경망 레이어를 빌드하는)와 SentencePiece (토큰화용)를 사용합니다.

import os
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm

CodeGemma 모델 로드

다음 세 가지 인수를 사용하는 kagglehub.model_download를 사용하여 CodeGemma 모델을 로드합니다.

  • handle: Kaggle의 모델 핸들
  • path: (선택사항 문자열) 로컬 경로
  • force_download: (선택적 불리언) 모델을 강제로 다시 다운로드합니다.
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7)
Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download...
100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3
드림

모델 가중치와 tokenizer의 위치를 확인한 다음 경로 변수를 설정합니다. tokenizer 디렉터리는 모델을 다운로드한 기본 디렉터리에 있고 모델 가중치는 하위 디렉터리에 있습니다. 예를 들면 다음과 같습니다.

  • spm.model tokenizer 파일은 /LOCAL/PATH/TO/codegemma/flax/2b-pt/3에 있습니다.
  • 모델 체크포인트가 /LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt에 위치하게 됩니다.
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model

샘플링/추론 수행

gemma.params.load_and_format_params 메서드를 사용하여 CodeGemma 모델 체크포인트를 로드하고 형식을 지정합니다.

params = params_lib.load_and_format_params(CKPT_PATH)

sentencepiece.SentencePieceProcessor를 사용하여 구성된 CodeGemma tokenizer를 로드합니다.

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True

CodeGemma 모델 체크포인트에서 올바른 구성을 자동으로 로드하려면 gemma.transformer.TransformerConfig를 사용합니다. cache_size 인수는 CodeGemma Transformer 캐시의 시간 단계 수입니다. 그런 다음 flax.linen.Module에서 상속되는 gemma.transformer.Transformer를 사용하여 CodeGemma 모델을 model_2b로 인스턴스화합니다.

transformer_config = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(config=transformer_config)

gemma.sampler.Samplersampler를 만듭니다. CodeGemma 모델 체크포인트와 tokenizer를 사용합니다.

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer']
)

중간 채우기 (fim) 토큰을 나타내는 변수를 만들고 프롬프트 및 생성된 출력의 형식을 지정하는 도우미 함수를 만듭니다.

예를 들어 다음 코드를 살펴보겠습니다.

def function(string):
assert function('asdf') == 'fdsa'

어설션이 True를 보유하도록 function를 채우려고 합니다. 이 경우 접두사는 다음과 같습니다.

"def function(string):\n"

접미사는 다음과 같습니다.

"assert function('asdf') == 'fdsa'"

그런 다음 이를 프롬프트의 형식으로 PREFIX-SUFFIX-MIDDLE (입력해야 하는 중간 섹션은 항상 프롬프트의 끝에 있음)로 지정합니다.

"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"

def format_completion_prompt(before, after):
  print(f"\nORIGINAL PROMPT:\n{before}{after}")
  prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
  print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
  return prompt
def format_generated_output(before, after, output):
  print(f"\nGENERATED OUTPUT:\n{repr(output)}")
  formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
  print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
  return formatted_output

프롬프트를 만들고 추론을 수행합니다. 프리픽스 before 텍스트와 서픽스 after 텍스트를 지정하고 도우미 함수 format_completion prompt를 사용하여 형식이 지정된 프롬프트를 생성합니다.

total_generation_steps (응답을 생성할 때 수행되는 단계 수)를 조정할 수 있습니다. 이 예에서는 100를 사용하여 호스트 메모리를 보존합니다.

before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
def function(string):
assert function('asdf') == 'fdsa'

FORMATTED PROMPT:
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"

GENERATED OUTPUT:
'    return string[::-1]\n\n<|file_separator|>'

FILL-IN COMPLETION:
def function(string):
    return string[::-1]

assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n    sys.exit(0)"""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import if __name__ == "__main__":
    sys.exit(0)

FORMATTED PROMPT:
'<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n    sys.exit(0)<|fim_middle|>'

GENERATED OUTPUT:
'sys\n<|file_separator|>'

FILL-IN COMPLETION:
import sys
if __name__ == "__main__":
    sys.exit(0)
before = """import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix


FORMATTED PROMPT:
'<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n  # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>'

GENERATED OUTPUT:
'  return np.flip(matrix, axis=1)\n<|file_separator|>'

FILL-IN COMPLETION:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
  return np.flip(matrix, axis=1)

자세히 알아보기