RecurrentGemma를 사용한 추론

ai.google.dev에서 보기 Google Colab에서 실행 Vertex AI에서 열기 GitHub에서 소스 보기

이 튜토리얼에서는 JAX (고성능 수치 컴퓨팅 라이브러리), Flax (JAX 기반 신경망 라이브러리), Orbax (학습 유틸리티 (예: 체크포인트_포인트_채택)와 같은 JAX 기반 신경망 라이브러리인 Orbax로 작성된 Google DeepMind의 recurrentgemma 라이브러리를 사용하여 RecurrentGemma 2B Instruct 모델로 기본 샘플링/추론을 수행하는 방법을 보여줍니다.SentencePiece Flax는 이 노트북에서 직접 사용되지는 않지만 Gemma와 RecurrentGemma (Griffin 모델)를 만드는 데 Flax를 사용했습니다.

이 노트북은 T4 GPU를 사용하는 Google Colab에서 실행할 수 있습니다 (수정 > 노트북 설정으로 이동한 후 하드웨어 가속기에서 T4 GPU 선택).

설정

다음 섹션에서는 모델 액세스, API 키 가져오기, 노트북 런타임 구성 등 RecurrentGemma 모델을 사용하기 위해 노트북을 준비하는 단계를 설명합니다.

Gemma에 Kaggle 액세스 권한 설정하기

이 튜토리얼을 완료하려면 먼저 Gemma 설정비슷한 설정 안내를 따라야 하지만 몇 가지 예외가 있습니다.

  • kaggle.com에서 Gemma 대신 RecurrentGemma에 액세스하세요.
  • RecurrentGemma 모델을 실행하기에 충분한 리소스가 있는 Colab 런타임을 선택하세요.
  • Kaggle 사용자 이름 및 API 키를 생성하고 구성합니다.

RecurrentGemma 설정을 완료한 후 다음 섹션으로 이동하여 Colab 환경의 환경 변수를 설정합니다.

환경 변수 설정하기

KAGGLE_USERNAMEKAGGLE_KEY의 환경 변수를 설정합니다. '액세스 권한을 부여하시겠습니까?'라는 메시지가 표시되면 보안 비밀 액세스 제공에 동의합니다.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

recurrentgemma 라이브러리 설치

이 노트북에서는 무료 Colab GPU 사용에 중점을 둡니다. 하드웨어 가속을 사용 설정하려면 수정 > 노트북 설정 > T4 GPU 선택 > 저장을 클릭합니다.

다음으로 github.com/google-deepmind/recurrentgemma에서 Google DeepMind recurrentgemma 라이브러리를 설치해야 합니다. 'pip의 종속 항목 리졸버'에 관한 오류가 발생하면 일반적으로 무시해도 됩니다.

pip install git+https://github.com/google-deepmind/recurrentgemma.git

RecurrentGemma 모델 로드 및 준비

  1. 세 가지 인수를 사용하는 kagglehub.model_download를 사용하여 RecurrentGemma 모델을 로드합니다.
  • handle: Kaggle의 모델 핸들
  • path: (선택사항 문자열) 로컬 경로
  • force_download: (선택적 불리언) 모델을 강제로 다시 다운로드합니다.
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download...
100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s]
Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
  1. 모델 가중치와 tokenizer의 위치를 확인한 다음 경로 변수를 설정합니다. tokenizer 디렉터리는 모델을 다운로드한 기본 디렉터리에 있고 모델 가중치는 하위 디렉터리에 있습니다. 예를 들면 다음과 같습니다.
  • tokenizer.model 파일은 /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1에 있습니다.
  • 모델 체크포인트는 /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it에 있습니다.
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model

샘플링/추론 수행

  1. recurrentgemma.jax.load_parameters 메서드로 RecurrentGemma 모델 체크포인트를 로드합니다. "single_device"로 설정된 sharding 인수는 단일 기기에 모든 모델 매개변수를 로드합니다.
import recurrentgemma
from recurrentgemma import jax as recurrentgemma

params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
  1. sentencepiece.SentencePieceProcessor를 사용하여 구성된 RecurrentGemma 모델 tokenizer를 로드합니다.
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. RecurrentGemma 모델 체크포인트에서 올바른 구성을 자동으로 로드하려면 recurrentgemma.GriffinConfig.from_flax_params_or_variables를 사용합니다. 그런 다음 recurrentgemma.jax.Griffin를 사용하여 Griffin 모델을 인스턴스화합니다.
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
    flax_params_or_variables=params)

model = recurrentgemma.Griffin(model_config)
  1. RecurrentGemma 모델 체크포인트/가중치 및 tokenizer 위에 recurrentgemma.jax.Sampler를 사용하여 sampler를 만듭니다.
sampler = recurrentgemma.Sampler(
    model=model,
    vocab=vocab,
    params=params,
)
  1. prompt로 프롬프트를 작성하고 추론을 수행합니다. total_generation_steps (응답을 생성할 때 수행되는 단계 수)를 조정할 수 있습니다. 이 예에서는 50를 사용하여 호스트 메모리를 보존합니다.
prompt = [
    "\n# 5+9=?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=50,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
Prompt:

# 5+9=?
Output:


# Answer: 14

# Explanation: 5 + 9 = 14.

자세히 알아보기