ai.google.dev에서 보기 | Google Colab에서 실행 | Vertex AI에서 열기 | GitHub에서 소스 보기 |
이 튜토리얼에서는 JAX (고성능 수치 컴퓨팅 라이브러리), Flax (JAX 기반 신경망 라이브러리), Orbax (학습 유틸리티 (예: 체크포인트_포인트_채택)와 같은 JAX 기반 신경망 라이브러리인 Orbax로 작성된 Google DeepMind의 recurrentgemma
라이브러리를 사용하여 RecurrentGemma 2B Instruct 모델로 기본 샘플링/추론을 수행하는 방법을 보여줍니다.SentencePiece Flax는 이 노트북에서 직접 사용되지는 않지만 Gemma와 RecurrentGemma (Griffin 모델)를 만드는 데 Flax를 사용했습니다.
이 노트북은 T4 GPU를 사용하는 Google Colab에서 실행할 수 있습니다 (수정 > 노트북 설정으로 이동한 후 하드웨어 가속기에서 T4 GPU 선택).
설정
다음 섹션에서는 모델 액세스, API 키 가져오기, 노트북 런타임 구성 등 RecurrentGemma 모델을 사용하기 위해 노트북을 준비하는 단계를 설명합니다.
Gemma에 Kaggle 액세스 권한 설정하기
이 튜토리얼을 완료하려면 먼저 Gemma 설정과 비슷한 설정 안내를 따라야 하지만 몇 가지 예외가 있습니다.
- kaggle.com에서 Gemma 대신 RecurrentGemma에 액세스하세요.
- RecurrentGemma 모델을 실행하기에 충분한 리소스가 있는 Colab 런타임을 선택하세요.
- Kaggle 사용자 이름 및 API 키를 생성하고 구성합니다.
RecurrentGemma 설정을 완료한 후 다음 섹션으로 이동하여 Colab 환경의 환경 변수를 설정합니다.
환경 변수 설정하기
KAGGLE_USERNAME
및 KAGGLE_KEY
의 환경 변수를 설정합니다. '액세스 권한을 부여하시겠습니까?'라는 메시지가 표시되면 보안 비밀 액세스 제공에 동의합니다.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
recurrentgemma
라이브러리 설치
이 노트북에서는 무료 Colab GPU 사용에 중점을 둡니다. 하드웨어 가속을 사용 설정하려면 수정 > 노트북 설정 > T4 GPU 선택 > 저장을 클릭합니다.
다음으로 github.com/google-deepmind/recurrentgemma
에서 Google DeepMind recurrentgemma
라이브러리를 설치해야 합니다. 'pip의 종속 항목 리졸버'에 관한 오류가 발생하면 일반적으로 무시해도 됩니다.
pip install git+https://github.com/google-deepmind/recurrentgemma.git
RecurrentGemma 모델 로드 및 준비
- 세 가지 인수를 사용하는
kagglehub.model_download
를 사용하여 RecurrentGemma 모델을 로드합니다.
handle
: Kaggle의 모델 핸들path
: (선택사항 문자열) 로컬 경로force_download
: (선택적 불리언) 모델을 강제로 다시 다운로드합니다.
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub
RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download... 100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s] Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
- 모델 가중치와 tokenizer의 위치를 확인한 다음 경로 변수를 설정합니다. tokenizer 디렉터리는 모델을 다운로드한 기본 디렉터리에 있고 모델 가중치는 하위 디렉터리에 있습니다. 예를 들면 다음과 같습니다.
tokenizer.model
파일은/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1
에 있습니다.- 모델 체크포인트는
/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it
에 있습니다.
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model
샘플링/추론 수행
recurrentgemma.jax.load_parameters
메서드로 RecurrentGemma 모델 체크포인트를 로드합니다."single_device"
로 설정된sharding
인수는 단일 기기에 모든 모델 매개변수를 로드합니다.
import recurrentgemma
from recurrentgemma import jax as recurrentgemma
params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
sentencepiece.SentencePieceProcessor
를 사용하여 구성된 RecurrentGemma 모델 tokenizer를 로드합니다.
import sentencepiece as spm
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
- RecurrentGemma 모델 체크포인트에서 올바른 구성을 자동으로 로드하려면
recurrentgemma.GriffinConfig.from_flax_params_or_variables
를 사용합니다. 그런 다음recurrentgemma.jax.Griffin
를 사용하여 Griffin 모델을 인스턴스화합니다.
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
flax_params_or_variables=params)
model = recurrentgemma.Griffin(model_config)
- RecurrentGemma 모델 체크포인트/가중치 및 tokenizer 위에
recurrentgemma.jax.Sampler
를 사용하여sampler
를 만듭니다.
sampler = recurrentgemma.Sampler(
model=model,
vocab=vocab,
params=params,
)
prompt
로 프롬프트를 작성하고 추론을 수행합니다.total_generation_steps
(응답을 생성할 때 수행되는 단계 수)를 조정할 수 있습니다. 이 예에서는50
를 사용하여 호스트 메모리를 보존합니다.
prompt = [
"\n# 5+9=?",
]
reply = sampler(input_strings=prompt,
total_generation_steps=50,
)
for input_string, out_string in zip(prompt, reply.text):
print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation. warnings.warn("Some donated buffers were not usable:" Prompt: # 5+9=? Output: # Answer: 14 # Explanation: 5 + 9 = 14.
자세히 알아보기
recurrentgemma.jax.load_parameters
,recurrentgemma.jax.Griffin
,recurrentgemma.jax.Sampler
등 이 튜토리얼에서 사용한 메서드 및 모듈의 docstring이 포함된 Google DeepMindrecurrentgemma
라이브러리에 관해 자세히 알아볼 수 있습니다.- core JAX, Flax, Orbax 라이브러리에는 자체 문서 사이트가 있습니다.
sentencepiece
tokenizer/detokenizer 문서는 Google의sentencepiece
GitHub 저장소를 확인하세요.kagglehub
문서는 Kaggle의kagglehub
GitHub 저장소에서README.md
를 확인하세요.- Google Cloud Vertex AI에서 Gemma 모델을 사용하는 방법 알아보기
- Google DeepMind의 RecurrentGemma: Moving Past Transformers for Efficient Open Language Models 문서를 확인하세요.
- RecurrentGemma에서 사용하는 모델 아키텍처에 대한 자세한 내용은 GoogleDeepMind의 Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models(효율적인 언어 모델을 위한 로컬 어텐션과 결합) 자료를 읽어보세요.