Google DeepMind의 Gemma 모델 (Gemma Team et al., 2024년). CodeGemma는 Gemini 모델을 만드는 데 사용되는 것과 동일한 연구와 기술로 빌드된 최첨단 경량 개방형 모델군입니다.
Gemma 사전 학습된 모델을 이어받아 CodeGemma 모델은 주로 코드의 5, 000억~1, 0000억 개 토큰을 대상으로 Gemma 모델 제품군과 동일한 아키텍처를 사용하여 추가 학습됩니다. 그 결과 CodeGemma 모델은 강력한 이해 및 추론 기술을 대규모로 유지하면서 완성 및 생성 태스크 모두에서 최신 코드 성능을 달성합니다.
CodeGemma에는 3가지 변형이 있습니다.
- 7B 코드 사전 학습 모델
- 7B 명령 조정 코드 모델
- 코드 채우기 및 개방형 생성을 위해 특별히 학습된 2B 모델입니다.
이 가이드에서는 코드 완성 작업에 Flax와 함께 CodeGemma 모델을 사용하는 방법을 안내합니다.
설정
1. CodeGemma의 Kaggle 액세스 설정
이 튜토리얼을 완료하려면 먼저 Gemma 설정의 설정 안내를 따라 다음 작업을 수행하는 방법을 알아야 합니다.
- kaggle.com에서 CodeGemma에 액세스합니다.
- 리소스가 충분한 Colab 런타임 (T4 GPU의 메모리가 부족하므로 TPU v2를 대신 사용)을 선택하여 CodeGemma 모델을 실행합니다.
- Kaggle 사용자 이름과 API 키를 생성하고 구성합니다.
Gemma 설정을 완료한 후 다음 섹션으로 이동하여 Colab 환경의 환경 변수를 설정합니다.
2. 환경 변수 설정하기
KAGGLE_USERNAME
및 KAGGLE_KEY
의 환경 변수를 설정합니다. '액세스 권한을 부여하시겠습니까?'라는 메시지가 표시되면 보안 비밀 액세스 권한을 제공하는 데 동의합니다.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
3. gemma
라이브러리 설치
현재 무료 Colab 하드웨어 가속으로는 이 노트북을 실행하기에 insufficient. Colab Pay As You Go 또는 Colab Pro를 사용하는 경우 수정 > 노트북 설정 > A100 GPU 선택 > 저장을 클릭하여 하드웨어 가속을 사용 설정합니다.
다음으로 github.com/google-deepmind/gemma
에서 Google DeepMind gemma
라이브러리를 설치해야 합니다. 'pip의 종속 항목 리졸버'에 관한 오류가 발생하면 일반적으로 무시해도 됩니다.
pip install -q git+https://github.com/google-deepmind/gemma.git
4. 라이브러리 가져오기
이 노트북에서는 Gemma (Flax를 사용하여 신경망 레이어를 빌드함)와 SentencePiece (토큰화용)를 사용합니다.
import os
from gemma.deprecated import params as params_lib
from gemma.deprecated import sampler as sampler_lib
from gemma.deprecated import transformer as transformer_lib
import sentencepiece as spm
CodeGemma 모델 로드
세 가지 인수를 사용하는 kagglehub.model_download
를 사용하여 CodeGemma 모델을 로드합니다.
handle
: Kaggle의 모델 핸들path
: (선택사항 문자열) 로컬 경로force_download
: (선택사항 부울) 모델을 다시 다운로드하도록 강제합니다.
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub
GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7) Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download... 100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s] Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3
모델 가중치 및 토큰라이저의 위치를 확인한 다음 경로 변수를 설정합니다. 토큰라이저 디렉터리는 모델을 다운로드한 기본 디렉터리에 있고 모델 가중치는 하위 디렉터리에 있습니다. 예를 들면 다음과 같습니다.
spm.model
토큰라이저 파일은/LOCAL/PATH/TO/codegemma/flax/2b-pt/3
에 있습니다.- 모델 체크포인트는
/LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt
에 있습니다.
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model
샘플링/추론 실행
gemma.params.load_and_format_params
메서드를 사용하여 CodeGemma 모델 체크포인트를 로드하고 형식을 지정합니다.
params = params_lib.load_and_format_params(CKPT_PATH)
sentencepiece.SentencePieceProcessor
를 사용하여 생성된 CodeGemma 토큰라이저를 로드합니다.
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
CodeGemma 모델 체크포인트에서 올바른 구성을 자동으로 로드하려면 gemma.deprecated.transformer.TransformerConfig
를 사용하세요. cache_size
인수는 CodeGemma Transformer
캐시의 시간 단계 수입니다. 그런 다음 gemma.deprecated.transformer.Transformer
(flax.linen.Module
에서 상속됨)를 사용하여 CodeGemma 모델을 model_2b
로 인스턴스화합니다.
transformer_config = transformer_lib.TransformerConfig.from_params(
params,
cache_size=1024
)
transformer = transformer_lib.Transformer(config=transformer_config)
gemma.sampler.Sampler
로 sampler
를 만듭니다. CodeGemma 모델 체크포인트와 토큰라이저를 사용합니다.
sampler = sampler_lib.Sampler(
transformer=transformer,
vocab=vocab,
params=params['transformer']
)
필인더미드 (fim) 토큰을 나타내는 변수를 만들고 프롬프트와 생성된 출력의 형식을 지정하는 도우미 함수를 만듭니다.
예를 들어 다음 코드를 살펴보겠습니다.
def function(string):
assert function('asdf') == 'fdsa'
어설션이 True
를 유지하도록 function
를 채우고자 합니다. 이 경우 접두사는 다음과 같습니다.
"def function(string):\n"
접미사는 다음과 같습니다.
"assert function('asdf') == 'fdsa'"
그런 다음 이를 PREFIX-SUFFIX-MIDDLE 형식으로 프롬프트로 지정합니다 (작성해야 하는 중간 섹션은 항상 프롬프트 끝에 있음).
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"
def format_completion_prompt(before, after):
print(f"\nORIGINAL PROMPT:\n{before}{after}")
prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
return prompt
def format_generated_output(before, after, output):
print(f"\nGENERATED OUTPUT:\n{repr(output)}")
formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
return formatted_output
프롬프트를 만들고 추론을 실행합니다. 접두사 before
텍스트와 접미사 after
텍스트를 지정하고 도우미 함수 format_completion prompt
를 사용하여 형식이 지정된 프롬프트를 생성합니다.
total_generation_steps
(응답을 생성할 때 실행되는 단계 수)를 조정할 수 있습니다. 이 예에서는 100
를 사용하여 호스트 메모리를 보존합니다.
before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: def function(string): assert function('asdf') == 'fdsa' FORMATTED PROMPT: "<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>" GENERATED OUTPUT: ' return string[::-1]\n\n<|file_separator|>' FILL-IN COMPLETION: def function(string): return string[::-1] assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n sys.exit(0)"""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import if __name__ == "__main__": sys.exit(0) FORMATTED PROMPT: '<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n sys.exit(0)<|fim_middle|>' GENERATED OUTPUT: 'sys\n<|file_separator|>' FILL-IN COMPLETION: import sys if __name__ == "__main__": sys.exit(0)
before = """import numpy as np
def reflect(matrix):
# horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import numpy as np def reflect(matrix): # horizontally reflect a matrix FORMATTED PROMPT: '<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>' GENERATED OUTPUT: ' return np.flip(matrix, axis=1)\n<|file_separator|>' FILL-IN COMPLETION: import numpy as np def reflect(matrix): # horizontally reflect a matrix return np.flip(matrix, axis=1)
자세히 알아보기
- 이 튜토리얼에서 사용한 모듈(예:
gemma.params
,gemma.deprecated.transformer
,gemma.sampler
)의 문자열 문서가 포함된 Google DeepMindgemma
라이브러리(GitHub)에 대해 자세히 알아보세요. - 핵심 JAX, Flax, Orbax 라이브러리에는 자체 문서 사이트가 있습니다.
sentencepiece
토큰라이저/역토큰라이저 문서는 Google의sentencepiece
GitHub 저장소를 참고하세요.kagglehub
문서는 Kaggle의kagglehub
GitHub 저장소에서README.md
를 확인하세요.- Google Cloud Vertex AI에서 Gemma 모델을 사용하는 방법을 알아보세요.
- Google Cloud TPU (v3-8 이상)를 사용하는 경우 최신
jax[tpu]
패키지 (!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
)로 업데이트하고 런타임을 다시 시작하고jax
및jaxlib
버전이 일치하는지 확인합니다 (!pip list | grep jax
). 이렇게 하면jaxlib
및jax
버전 불일치로 인해 발생할 수 있는RuntimeError
를 방지할 수 있습니다. JAX 설치 안내는 JAX 문서를 참고하세요.