JAX 및 Flax를 사용하여 CodeGemma로 추론

Google DeepMind의 Gemma 모델 (Gemma Team et al., 2024년). CodeGemma는 Gemini 모델을 만드는 데 사용되는 것과 동일한 연구와 기술로 빌드된 최첨단 경량 개방형 모델군입니다.

Gemma 사전 학습된 모델을 이어받아 CodeGemma 모델은 주로 코드의 5, 000억~1, 0000억 개 토큰을 대상으로 Gemma 모델 제품군과 동일한 아키텍처를 사용하여 추가 학습됩니다. 그 결과 CodeGemma 모델은 강력한 이해 및 추론 기술을 대규모로 유지하면서 완성 및 생성 태스크 모두에서 최신 코드 성능을 달성합니다.

CodeGemma에는 3가지 변형이 있습니다.

  • 7B 코드 사전 학습 모델
  • 7B 명령 조정 코드 모델
  • 코드 채우기 및 개방형 생성을 위해 특별히 학습된 2B 모델입니다.

이 가이드에서는 코드 완성 작업에 Flax와 함께 CodeGemma 모델을 사용하는 방법을 안내합니다.

설정

1. CodeGemma의 Kaggle 액세스 설정

이 튜토리얼을 완료하려면 먼저 Gemma 설정의 설정 안내를 따라 다음 작업을 수행하는 방법을 알아야 합니다.

  • kaggle.com에서 CodeGemma에 액세스합니다.
  • 리소스가 충분한 Colab 런타임 (T4 GPU의 메모리가 부족하므로 TPU v2를 대신 사용)을 선택하여 CodeGemma 모델을 실행합니다.
  • Kaggle 사용자 이름과 API 키를 생성하고 구성합니다.

Gemma 설정을 완료한 후 다음 섹션으로 이동하여 Colab 환경의 환경 변수를 설정합니다.

2. 환경 변수 설정하기

KAGGLE_USERNAMEKAGGLE_KEY의 환경 변수를 설정합니다. '액세스 권한을 부여하시겠습니까?'라는 메시지가 표시되면 보안 비밀 액세스 권한을 제공하는 데 동의합니다.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. gemma 라이브러리 설치

현재 무료 Colab 하드웨어 가속으로는 이 노트북을 실행하기에 insufficient. Colab Pay As You Go 또는 Colab Pro를 사용하는 경우 수정 > 노트북 설정 > A100 GPU 선택 > 저장을 클릭하여 하드웨어 가속을 사용 설정합니다.

다음으로 github.com/google-deepmind/gemma에서 Google DeepMind gemma 라이브러리를 설치해야 합니다. 'pip의 종속 항목 리졸버'에 관한 오류가 발생하면 일반적으로 무시해도 됩니다.

pip install -q git+https://github.com/google-deepmind/gemma.git

4. 라이브러리 가져오기

이 노트북에서는 Gemma (Flax를 사용하여 신경망 레이어를 빌드함)와 SentencePiece (토큰화용)를 사용합니다.

import os
from gemma.deprecated import params as params_lib
from gemma.deprecated import sampler as sampler_lib
from gemma.deprecated import transformer as transformer_lib
import sentencepiece as spm

CodeGemma 모델 로드

세 가지 인수를 사용하는 kagglehub.model_download를 사용하여 CodeGemma 모델을 로드합니다.

  • handle: Kaggle의 모델 핸들
  • path: (선택사항 문자열) 로컬 경로
  • force_download: (선택사항 부울) 모델을 다시 다운로드하도록 강제합니다.
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7)
Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download...
100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3

모델 가중치 및 토큰라이저의 위치를 확인한 다음 경로 변수를 설정합니다. 토큰라이저 디렉터리는 모델을 다운로드한 기본 디렉터리에 있고 모델 가중치는 하위 디렉터리에 있습니다. 예를 들면 다음과 같습니다.

  • spm.model 토큰라이저 파일은 /LOCAL/PATH/TO/codegemma/flax/2b-pt/3에 있습니다.
  • 모델 체크포인트는 /LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt에 있습니다.
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model

샘플링/추론 실행

gemma.params.load_and_format_params 메서드를 사용하여 CodeGemma 모델 체크포인트를 로드하고 형식을 지정합니다.

params = params_lib.load_and_format_params(CKPT_PATH)

sentencepiece.SentencePieceProcessor를 사용하여 생성된 CodeGemma 토큰라이저를 로드합니다.

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True

CodeGemma 모델 체크포인트에서 올바른 구성을 자동으로 로드하려면 gemma.deprecated.transformer.TransformerConfig를 사용하세요. cache_size 인수는 CodeGemma Transformer 캐시의 시간 단계 수입니다. 그런 다음 gemma.deprecated.transformer.Transformer (flax.linen.Module에서 상속됨)를 사용하여 CodeGemma 모델을 model_2b로 인스턴스화합니다.

transformer_config = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(config=transformer_config)

gemma.sampler.Samplersampler를 만듭니다. CodeGemma 모델 체크포인트와 토큰라이저를 사용합니다.

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer']
)

필인더미드 (fim) 토큰을 나타내는 변수를 만들고 프롬프트와 생성된 출력의 형식을 지정하는 도우미 함수를 만듭니다.

예를 들어 다음 코드를 살펴보겠습니다.

def function(string):
assert function('asdf') == 'fdsa'

어설션이 True를 유지하도록 function를 채우고자 합니다. 이 경우 접두사는 다음과 같습니다.

"def function(string):\n"

접미사는 다음과 같습니다.

"assert function('asdf') == 'fdsa'"

그런 다음 이를 PREFIX-SUFFIX-MIDDLE 형식으로 프롬프트로 지정합니다 (작성해야 하는 중간 섹션은 항상 프롬프트 끝에 있음).

"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"

def format_completion_prompt(before, after):
  print(f"\nORIGINAL PROMPT:\n{before}{after}")
  prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
  print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
  return prompt
def format_generated_output(before, after, output):
  print(f"\nGENERATED OUTPUT:\n{repr(output)}")
  formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
  print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
  return formatted_output

프롬프트를 만들고 추론을 실행합니다. 접두사 before 텍스트와 접미사 after 텍스트를 지정하고 도우미 함수 format_completion prompt를 사용하여 형식이 지정된 프롬프트를 생성합니다.

total_generation_steps (응답을 생성할 때 실행되는 단계 수)를 조정할 수 있습니다. 이 예에서는 100를 사용하여 호스트 메모리를 보존합니다.

before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
def function(string):
assert function('asdf') == 'fdsa'

FORMATTED PROMPT:
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"

GENERATED OUTPUT:
'    return string[::-1]\n\n<|file_separator|>'

FILL-IN COMPLETION:
def function(string):
    return string[::-1]

assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n    sys.exit(0)"""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import if __name__ == "__main__":
    sys.exit(0)

FORMATTED PROMPT:
'<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n    sys.exit(0)<|fim_middle|>'

GENERATED OUTPUT:
'sys\n<|file_separator|>'

FILL-IN COMPLETION:
import sys
if __name__ == "__main__":
    sys.exit(0)
before = """import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix


FORMATTED PROMPT:
'<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n  # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>'

GENERATED OUTPUT:
'  return np.flip(matrix, axis=1)\n<|file_separator|>'

FILL-IN COMPLETION:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
  return np.flip(matrix, axis=1)

자세히 알아보기