ai.google.dev에서 보기 | Google Colab에서 실행 | GitHub에서 소스 보기 |
Google DeepMind의 Gemma 모델 (Gemma Team 외, 2024년). CodeGemma는 Gemini 모델을 만드는 데 사용된 것과 동일한 연구와 기술을 바탕으로 빌드된 최첨단 경량 개방형 모델 제품군입니다.
Gemma 선행 학습된 모델에서 계속 진행하면 CodeGemma 모델은 주로 500~1,000억 개 이상의 토큰으로 구성된 Gemma 모델군과 동일한 아키텍처를 사용합니다 결과적으로 CodeGemma 모델은 완성된 두 코드 모두에서 최고의 코드 성능을 달성합니다. 생성 작업을 자동화하는 동시에 대규모 언어 모델을 학습시킬 수 있습니다
CodeGemma에는 세 가지 변형이 있습니다.
- 7B 코드 선행 학습된 모델
- 7B 명령 조정 코드 모델
- 코드 입력 및 개방형 생성을 위해 특별히 학습된 2B 모델입니다.
이 가이드에서는 코드 완성 작업을 위해 Flax와 함께 CodeGemma 모델을 사용하는 방법을 설명합니다.
설정
1. CodeGemma에 Kaggle 액세스 권한 설정하기
이 튜토리얼을 완료하려면 먼저 Gemma 설정에서 다음을 수행하는 방법을 보여주는 설정 안내를 따라야 합니다.
- kaggle.com에서 CodeGemma에 액세스하세요.
- 리소스가 충분한 Colab 런타임 (T4 GPU 메모리가 부족합니다. 대신 TPU v2를 사용하세요)을 선택하여 CodeGemma 모델을 실행합니다.
- Kaggle 사용자 이름 및 API 키를 생성하고 구성합니다.
Gemma 설정을 완료한 후 다음 섹션으로 이동하여 Colab 환경의 환경 변수를 설정합니다.
2. 환경 변수 설정하기
KAGGLE_USERNAME
및 KAGGLE_KEY
의 환경 변수를 설정합니다. '액세스 권한을 부여하시겠습니까?'라는 메시지가 표시되면 비밀 액세스 제공에 동의해야 합니다.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
3. gemma
라이브러리 설치
현재 무료 Colab 하드웨어 가속으로는 이 노트북을 실행할 수 없습니다. Colab 종량제 또는 Colab Pro를 사용하는 경우 수정을 클릭합니다. 노트북 설정 > A100 GPU를 선택합니다. 저장하여 하드웨어 가속을 사용 설정합니다.
다음으로 github.com/google-deepmind/gemma
에서 Google DeepMind gemma
라이브러리를 설치해야 합니다. 'pip의 종속 항목 리졸버'에 관한 오류가 발생하면 일반적으로 무시해도 됩니다.
pip install -q git+https://github.com/google-deepmind/gemma.git
4. 라이브러리 가져오기
이 노트북은 Gemma (Flax를 사용하여 신경망 레이어를 빌드하는)와 SentencePiece (토큰화용)를 사용합니다.
import os
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm
CodeGemma 모델 로드
다음 세 가지 인수를 사용하는 kagglehub.model_download
를 사용하여 CodeGemma 모델을 로드합니다.
handle
: Kaggle의 모델 핸들path
: (선택사항 문자열) 로컬 경로force_download
: (선택적 불리언) 모델을 강제로 다시 다운로드합니다.
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub
GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7) Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download... 100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s] Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3
모델 가중치와 tokenizer의 위치를 확인한 다음 경로 변수를 설정합니다. tokenizer 디렉터리는 모델을 다운로드한 기본 디렉터리에 있고 모델 가중치는 하위 디렉터리에 있습니다. 예를 들면 다음과 같습니다.
spm.model
tokenizer 파일은/LOCAL/PATH/TO/codegemma/flax/2b-pt/3
에 있습니다.- 모델 체크포인트가
/LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt
에 위치하게 됩니다.
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model
샘플링/추론 수행
gemma.params.load_and_format_params
메서드를 사용하여 CodeGemma 모델 체크포인트를 로드하고 형식을 지정합니다.
params = params_lib.load_and_format_params(CKPT_PATH)
sentencepiece.SentencePieceProcessor
를 사용하여 구성된 CodeGemma tokenizer를 로드합니다.
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
CodeGemma 모델 체크포인트에서 올바른 구성을 자동으로 로드하려면 gemma.transformer.TransformerConfig
를 사용합니다. cache_size
인수는 CodeGemma Transformer
캐시의 시간 단계 수입니다. 그런 다음 flax.linen.Module
에서 상속되는 gemma.transformer.Transformer
를 사용하여 CodeGemma 모델을 model_2b
로 인스턴스화합니다.
transformer_config = transformer_lib.TransformerConfig.from_params(
params,
cache_size=1024
)
transformer = transformer_lib.Transformer(config=transformer_config)
gemma.sampler.Sampler
로 sampler
를 만듭니다. CodeGemma 모델 체크포인트와 tokenizer를 사용합니다.
sampler = sampler_lib.Sampler(
transformer=transformer,
vocab=vocab,
params=params['transformer']
)
중간 채우기 (fim) 토큰을 나타내는 변수를 만들고 프롬프트 및 생성된 출력의 형식을 지정하는 도우미 함수를 만듭니다.
예를 들어 다음 코드를 살펴보겠습니다.
def function(string):
assert function('asdf') == 'fdsa'
어설션이 True
를 보유하도록 function
를 채우려고 합니다. 이 경우 접두사는 다음과 같습니다.
"def function(string):\n"
접미사는 다음과 같습니다.
"assert function('asdf') == 'fdsa'"
그런 다음 이를 프롬프트의 형식으로 PREFIX-SUFFIX-MIDDLE (입력해야 하는 중간 섹션은 항상 프롬프트의 끝에 있음)로 지정합니다.
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"
def format_completion_prompt(before, after):
print(f"\nORIGINAL PROMPT:\n{before}{after}")
prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
return prompt
def format_generated_output(before, after, output):
print(f"\nGENERATED OUTPUT:\n{repr(output)}")
formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
return formatted_output
프롬프트를 만들고 추론을 수행합니다. 프리픽스 before
텍스트와 서픽스 after
텍스트를 지정하고 도우미 함수 format_completion prompt
를 사용하여 형식이 지정된 프롬프트를 생성합니다.
total_generation_steps
(응답을 생성할 때 수행되는 단계 수)를 조정할 수 있습니다. 이 예에서는 100
를 사용하여 호스트 메모리를 보존합니다.
before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: def function(string): assert function('asdf') == 'fdsa' FORMATTED PROMPT: "<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>" GENERATED OUTPUT: ' return string[::-1]\n\n<|file_separator|>' FILL-IN COMPLETION: def function(string): return string[::-1] assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n sys.exit(0)"""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import if __name__ == "__main__": sys.exit(0) FORMATTED PROMPT: '<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n sys.exit(0)<|fim_middle|>' GENERATED OUTPUT: 'sys\n<|file_separator|>' FILL-IN COMPLETION: import sys if __name__ == "__main__": sys.exit(0)
before = """import numpy as np
def reflect(matrix):
# horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import numpy as np def reflect(matrix): # horizontally reflect a matrix FORMATTED PROMPT: '<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>' GENERATED OUTPUT: ' return np.flip(matrix, axis=1)\n<|file_separator|>' FILL-IN COMPLETION: import numpy as np def reflect(matrix): # horizontally reflect a matrix return np.flip(matrix, axis=1)
자세히 알아보기
- 이 튜토리얼에서 사용한 모듈의 docstring이 포함된
gemma.params
, GitHub의gemma
라이브러리에 대해 자세히 알아볼 수 있습니다.gemma.transformer
및gemma.sampler
. - core JAX, Flax, Orbax 라이브러리에는 자체 문서 사이트가 있습니다.
sentencepiece
tokenizer/detokenizer 문서는 Google의sentencepiece
GitHub 저장소를 확인하세요.kagglehub
문서는 Kaggle의kagglehub
GitHub 저장소에서README.md
를 확인하세요.- Google Cloud Vertex AI에서 Gemma 모델을 사용하는 방법 알아보기
- Google Cloud TPU (v3-8 이상)를 사용하는 경우 최신
jax[tpu]
패키지 (!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
)로 업데이트하고 런타임을 다시 시작한 다음jax
및jaxlib
버전이 일치하는지 (!pip list | grep jax
) 확인합니다. 이렇게 하면jaxlib
및jax
버전 불일치로 인해 발생할 수 있는RuntimeError
을 방지할 수 있습니다. JAX 설치에 대한 자세한 안내는 JAX 문서를 참조하세요.