PyTorch를 사용하여 Gemma 실행

이 가이드에서는 PyTorch 프레임워크를 사용하여 Gemma를 실행하는 방법을 보여줍니다. 여기에는 Gemma 버전 3 이상 모델을 프롬프트하는 데 이미지 데이터를 사용하는 방법이 포함됩니다. Gemma PyTorch 구현에 관한 자세한 내용은 프로젝트 저장소 리드미를 참고하세요.

설정

다음 섹션에서는 Kaggle에서 다운로드할 Gemma 모델에 액세스하는 방법, 인증 변수 설정, 종속 항목 설치, 패키지 가져오기 등 개발 환경을 설정하는 방법을 설명합니다.

시스템 요구사항

이 Gemma Pytorch 라이브러리에서는 Gemma 모델을 실행하기 위해 GPU 또는 TPU 프로세서가 필요합니다. 표준 Colab CPU Python 런타임과 T4 GPU Python 런타임은 Gemma 1B, 2B, 4B 크기 모델을 실행하는 데 충분합니다. 다른 GPU 또는 TPU의 고급 사용 사례는 Gemma PyTorch 저장소의 README를 참고하세요.

Kaggle에서 Gemma에 액세스하기

이 튜토리얼을 완료하려면 먼저 Gemma 설정의 설정 안내를 따라 다음 작업을 수행하는 방법을 알아야 합니다.

  • kaggle.com에서 Gemma에 액세스합니다.
  • Gemma 모델을 실행하는 데 충분한 리소스가 있는 Colab 런타임을 선택합니다.
  • Kaggle 사용자 이름과 API 키를 생성하고 구성합니다.

Gemma 설정을 완료한 후 다음 섹션으로 이동하여 Colab 환경의 환경 변수를 설정합니다.

환경 변수 설정하기

KAGGLE_USERNAMEKAGGLE_KEY의 환경 변수를 설정합니다. '액세스 권한을 부여하시겠습니까?'라는 메시지가 표시되면 보안 비밀 액세스 권한을 제공하는 데 동의합니다.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

종속 항목 설치

pip install -q -U torch immutabledict sentencepiece

모델 가중치 다운로드

# Choose variant and machine type
VARIANT = '4b-it' 
MACHINE_TYPE = 'cuda'

CONFIG = VARIANT[:2]
if CONFIG == '4b':
  CONFIG = '4b-v1'
import kagglehub

# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-3/pyTorch/gemma-3-{VARIANT}')

모델의 토큰 생성기 및 체크포인트 경로를 설정합니다.

# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

실행 환경 구성

다음 섹션에서는 Gemma를 실행하기 위한 PyTorch 환경을 준비하는 방법을 설명합니다.

PyTorch 실행 환경 준비

Gemma Pytorch 저장소를 클론하여 PyTorch 모델 실행 환경을 준비합니다.

git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 239, done.
remote: Counting objects: 100% (123/123), done.
remote: Compressing objects: 100% (68/68), done.
remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116
Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done.
Resolving deltas: 100% (135/135), done.
import sys

sys.path.append('gemma_pytorch/gemma')
from gemma_pytorch.gemma.config import get_model_config
from gemma_pytorch.gemma.gemma3_model import Gemma3ForMultimodalLM

import os
import torch

모델 구성 설정

모델을 실행하기 전에 Gemma 변형, 토큰라이저, 정량화 수준을 비롯한 몇 가지 구성 매개변수를 설정해야 합니다.

# Set up model config.
model_config = get_model_config(VARIANT)
model_config.dtype = "float32" if MACHINE_TYPE == "cpu" else "float16"
model_config.tokenizer = tokenizer_path

기기 컨텍스트 구성

다음 코드는 모델을 실행하기 위한 기기 컨텍스트를 구성합니다.

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(torch.float)

모델 인스턴스화 및 로드

요청 실행을 준비하기 위해 가중치가 있는 모델을 로드합니다.

device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
    model = Gemma3ForMultimodalLM(model_config)
    model.load_state_dict(torch.load(ckpt_path)['model_state_dict'])
    model = model.to(device).eval()
print("Model loading done.")

print('Generating requests in chat mode...')

추론 실행

다음은 채팅 모드에서 생성하고 여러 요청으로 생성하는 예입니다.

명령 조정 Gemma 모델은 학습 및 추론 중에 모두 명령 조정 예시를 추가 정보로 주석 처리하는 특정 형식 지정자로 학습되었습니다. 주석은 (1) 대화에서 역할을 나타내고 (2) 대화의 차례를 구분합니다.

관련 주석 토큰은 다음과 같습니다.

  • user: 사용자 차례
  • model: 모델 회전
  • <start_of_turn>: 대화 시작
  • <start_of_image>: 이미지 데이터 입력 태그
  • <end_of_turn><eos>: 대화 턴 종료

자세한 내용은 [여기](https://ai.google.dev/gemma/core/prompt-structure)에서 안내를 조정된 Gemma 모델의 프롬프트 형식 지정에 관해 알아보세요.

텍스트로 텍스트 생성

다음은 여러 번의 대화에서 사용자 및 모델 채팅 템플릿을 사용하여 안내 조정 Gemma 모델의 프롬프트 형식을 지정하는 방법을 보여주는 샘플 코드 스니펫입니다.

# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=256,
)
Chat prompt:
 <start_of_turn>user
What is a good place for travel in the US?<end_of_turn><eos>
<start_of_turn>model
California.<end_of_turn><eos>
<start_of_turn>user
What can I do in California?<end_of_turn><eos>
<start_of_turn>model
"California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo)  \n\nThe more you tell me, the better recommendations I can give! 😊  \n<end_of_turn>"
# Generate sample
model.generate(
    'Write a poem about an llm writing a poem.',
    device=device,
    output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"

이미지로 텍스트 생성

Gemma 버전 3 이상에서는 프롬프트와 함께 이미지를 사용할 수 있습니다. 다음 예는 프롬프트에 시각적 데이터를 포함하는 방법을 보여줍니다.

print('Chat with images...\n')

def read_image(url):
    import io
    import requests
    import PIL

    contents = io.BytesIO(requests.get(url).content)
    return PIL.Image.open(contents)

image_url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
image = read_image(image_url)

print(model.generate(
    [['<start_of_turn>user\n',image, 'What animal is in this image?<end_of_turn>\n', '<start_of_turn>model\n']],
    device=device,
    output_len=OUTPUT_LEN,
))

자세히 알아보기

이제 Pytorch에서 Gemma를 사용하는 방법을 배웠으므로 ai.google.dev/gemma에서 Gemma로 할 수 있는 다른 많은 작업을 살펴볼 수 있습니다. 다음과 같은 관련 리소스도 참고하세요.