PyTorch의 젬마

PyTorch에서 Gemma 추론을 실행하는 방법을 보여주는 간단한 데모입니다. 자세한 내용은 여기에서 공식 PyTorch 구현에 관한 GitHub 저장소를 확인하세요.

다음 사항에 유의하세요.

  • 무료 Colab CPU Python 런타임과 T4 GPU Python 런타임은 Gemma 2B 모델과 7B int8 정규화 모델을 실행하는 데 충분합니다.
  • 다른 GPU 또는 TPU의 고급 사용 사례는 공식 저장소의 README.md를 참고하세요.

1. Gemma의 Kaggle 액세스 설정

이 튜토리얼을 완료하려면 먼저 Gemma 설정의 설정 안내를 따라 다음 작업을 수행하는 방법을 알아야 합니다.

  • kaggle.com에서 Gemma에 액세스합니다.
  • Gemma 모델을 실행하기에 충분한 리소스가 있는 Colab 런타임을 선택하세요.
  • Kaggle 사용자 이름 및 API 키를 생성하고 구성합니다.

Gemma 설정을 완료한 후 다음 섹션으로 이동하여 Colab 환경의 환경 변수를 설정합니다.

2. 환경 변수 설정하기

KAGGLE_USERNAMEKAGGLE_KEY의 환경 변수를 설정합니다. '액세스 권한을 부여하시겠습니까?'라는 메시지가 표시되면 보안 비밀 액세스 권한을 제공하는 데 동의합니다.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

종속 항목 설치

pip install -q -U torch immutabledict sentencepiece

모델 가중치 다운로드

# Choose variant and machine type
VARIANT = '2b-it'
MACHINE_TYPE = 'cuda'

CONFIG = VARIANT[:2]
if CONFIG == '2b':
  CONFIG = '2b-v2'
import os
import kagglehub

# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-2/pyTorch/gemma-2-{VARIANT}')
# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

모델 구현 다운로드

# NOTE: The "installation" is just cloning the repo.
git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 239, done.
remote: Counting objects: 100% (123/123), done.
remote: Compressing objects: 100% (68/68), done.
remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116
Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done.
Resolving deltas: 100% (135/135), done.
import sys

sys.path.append('gemma_pytorch')
from gemma.config import GemmaConfig, get_model_config
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch

모델 설정

# Set up model config.
model_config = get_model_config(CONFIG)
model_config.tokenizer = tokenizer_path
model_config.quant = 'quant' in VARIANT

# Instantiate the model and load the weights.
torch.set_default_dtype(model_config.get_dtype())
device = torch.device(MACHINE_TYPE)
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()

추론 실행

다음은 채팅 모드에서 생성하고 여러 요청으로 생성하는 예입니다.

명령 조정 Gemma 모델은 학습 및 추론 중에 모두 명령 조정 예시를 추가 정보로 주석 처리하는 특정 형식 지정자로 학습되었습니다. 주석은 (1) 대화에서 역할을 나타내고 (2) 대화의 차례를 구분합니다.

관련 주석 토큰은 다음과 같습니다.

  • user: 사용자 차례
  • model: 모델 회전
  • <start_of_turn>: 대화의 시작 부분
  • <end_of_turn><eos>: 대화 차례의 끝

자세한 내용은 여기에서 명령어 조정 Gemma 모델의 프롬프트 형식 지정에 관해 알아보세요.

다음은 여러 번의 대화에서 사용자 및 모델 채팅 템플릿을 사용하여 명령 조정 Gemma 모델의 프롬프트 형식을 지정하는 방법을 보여주는 샘플 코드 스니펫입니다.

# Generate with one request in chat mode

# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=128,
)
Chat prompt:
 <start_of_turn>user
What is a good place for travel in the US?<end_of_turn><eos>
<start_of_turn>model
California.<end_of_turn><eos>
<start_of_turn>user
What can I do in California?<end_of_turn><eos>
<start_of_turn>model
"California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo)  \n\nThe more you tell me, the better recommendations I can give! 😊  \n<end_of_turn>"
# Generate sample
model.generate(
    'Write a poem about an llm writing a poem.',
    device=device,
    output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"

자세히 알아보기

이제 Pytorch에서 Gemma를 사용하는 방법을 알게 되었으므로 ai.google.dev/gemma에서 Gemma로 할 수 있는 다른 많은 작업을 살펴볼 수 있습니다. 다음과 같은 관련 리소스도 참고하세요.