Inference with Gemma using JAX and Flax


Overview

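Gemma is a family of lightweight, state-of-the-art open large language models built from the research and technology behind Google DeepMind's Gemini. This tutorial shows how to run basic sampling/inference with a Gemma 2 2B model using the Google DeepMind gemma library, which builds on JAX (a high-performance numerical computing library), Flax (a JAX-based neural network library), Orbax (a JAX-based library for training utilities such as checkpointing), and SentencePiece (a tokenizer library). Although this notebook does not use Flax directly, Flax was used to create Gemma.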

This notebook can run on Google Colab with a free T4 GPU (go to Edit > Notebook settings and select T4 GPU under Hardware accelerator).

Setup

1. Set up Kaggle access for Gemma

Before starting this tutorial, complete the setup instructions at Gemma setup. They show you how to do the following:

  • Get access to Gemma on kaggle.com.
  • Select a Colab runtime with enough resources to run the Gemma model.
  • Generate and configure a Kaggle username and API key.

After you complete the Gemma setup, move on to the next section, where you set environment variables for your Colab environment.

2. Set environment variables

Set the environment variables KAGGLE_USERNAME and KAGGLE_KEY. When you see the "Grant access?" prompt, agree to provide secret access.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
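Before downloading anything, it can help to fail early if the credentials did not make it into the environment. The sketch below is a hypothetical helper (not part of kagglehub or Colab) that only checks whether the two variables set above are present and non-empty; it does not validate the key against Kaggle.

```python
import os
from typing import List, Mapping, Optional

# The two variables set in the cell above.
REQUIRED_KAGGLE_VARS = ("KAGGLE_USERNAME", "KAGGLE_KEY")


def missing_kaggle_credentials(environ: Optional[Mapping[str, str]] = None) -> List[str]:
    """Return the required Kaggle variable names that are unset or empty.

    Hypothetical helper: it checks presence only, not whether the key is valid.
    """
    env = os.environ if environ is None else environ
    return [name for name in REQUIRED_KAGGLE_VARS if not env.get(name)]


missing = missing_kaggle_credentials()
if missing:
    print("Missing Kaggle credentials:", ", ".join(missing))
```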

3. Install the gemma library

This notebook focuses on using a free Colab GPU. To enable hardware acceleration, click Edit > Notebook settings, select T4 GPU, and then click Save.

Next, install the Google DeepMind gemma library from github.com/google-deepmind/gemma. If you get an error about pip's dependency resolver, you can usually ignore it.

pip install -q git+https://github.com/google-deepmind/gemma.git
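Since pip's dependency-resolver errors can usually be ignored, a quick way to confirm that the install actually worked is to check that the packages are importable. This sketch uses only the standard library's importlib.util.find_spec; the package list is our assumption, based on the imports used later in this notebook.

```python
import importlib.util
from typing import Iterable, List

# Top-level import names used in this notebook (assumed; adjust as needed).
NOTEBOOK_PACKAGES = ("gemma", "kagglehub", "sentencepiece", "jax", "flax")


def missing_packages(names: Iterable[str]) -> List[str]:
    """Return the top-level module names that cannot be found on sys.path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


print("Missing:", missing_packages(NOTEBOOK_PACKAGES) or "none")
```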

Load and prepare the Gemma model

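  1. Load the Gemma model with kagglehub.model_download, which takes three arguments:
  • handle: the model handle on Kaggle
  • path: (optional string) a path within the model files, to download only part of the model
  • force_download: (optional boolean) forces the model to be downloaded again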
GEMMA_VARIANT = 'gemma2-2b-it' # @param ['gemma2-2b', 'gemma2-2b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/gemma-2/flax/{GEMMA_VARIANT}')
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma-2-2b/flax/gemma2-2b-it/1
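The handle passed to kagglehub.model_download follows the pattern owner/model/framework/variation, and kagglehub also accepts an optional trailing version number (the cached path above ends in a version directory, /1). The hypothetical helper below builds the handle from the variant and can pin a version so that reruns fetch the same files. The allowed variants are taken from the @param list above.

```python
from typing import Optional

# Variants offered by the notebook's @param dropdown.
ALLOWED_VARIANTS = ("gemma2-2b", "gemma2-2b-it")


def gemma_flax_handle(variant: str, version: Optional[int] = None) -> str:
    """Build a handle of the form owner/model/framework/variation[/version]."""
    if variant not in ALLOWED_VARIANTS:
        raise ValueError(f"Unknown variant {variant!r}; expected one of {ALLOWED_VARIANTS}")
    handle = f"google/gemma-2/flax/{variant}"
    if version is not None:
        handle += f"/{version}"  # pin a specific model version
    return handle


print(gemma_flax_handle("gemma2-2b-it"))
```

The result can be passed straight to kagglehub.model_download in place of the f-string above.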
  2. Check the locations of the model weights and the tokenizer, then set the path variables. The tokenizer file is in the main directory where the model was downloaded, and the model weights are in a subdirectory. For example:
  • The tokenizer.model file is in /LOCAL/PATH/TO/gemma/flax/2b-it/2.
  • The model checkpoint is in /LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it.
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/gemma-2-2b/flax/gemma2-2b-it/1/gemma2-2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/gemma-2-2b/flax/gemma2-2b-it/1/tokenizer.model
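Printing the paths does not confirm that anything is actually there. The hypothetical helper below builds the same two paths, assuming the layout described above (tokenizer.model at the top level and the checkpoint in a subdirectory named after the variant), and raises a clear error if either one is missing, before the slower loading steps run.

```python
import os
from typing import Tuple


def resolve_gemma_paths(gemma_path: str, variant: str) -> Tuple[str, str]:
    """Return (checkpoint_dir, tokenizer_file), raising if either is missing.

    Assumes tokenizer.model sits in gemma_path and the checkpoint lives in
    a subdirectory named after the variant.
    """
    ckpt_path = os.path.join(gemma_path, variant)
    tokenizer_path = os.path.join(gemma_path, "tokenizer.model")
    if not os.path.isdir(ckpt_path):
        raise FileNotFoundError(f"Checkpoint directory not found: {ckpt_path}")
    if not os.path.isfile(tokenizer_path):
        raise FileNotFoundError(f"Tokenizer file not found: {tokenizer_path}")
    return ckpt_path, tokenizer_path
```

In this notebook you might call it as CKPT_PATH, TOKENIZER_PATH = resolve_gemma_paths(GEMMA_PATH, GEMMA_VARIANT).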

Perform sampling/inference

  1. Load and format the Gemma model checkpoint with the gemma.params.load_and_format_params method:
from gemma import params as params_lib

params = params_lib.load_and_format_params(CKPT_PATH)
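load_and_format_params returns nested dictionaries, which is why the sampler step below indexes params['transformer']. As we understand it, part of the "formatting" turns flat, '/'-separated checkpoint keys into that nesting. The sketch below re-implements only that idea for illustration; it is not the library's code, and the example keys are made up.

```python
from typing import Any, Dict


def nest_params(flat: Dict[str, Any], sep: str = "/") -> Dict[str, Any]:
    """Turn {'a/b/c': x} into {'a': {'b': {'c': x}}}."""
    nested: Dict[str, Any] = {}
    for path, value in flat.items():
        *parents, leaf = path.split(sep)
        node = nested
        for key in parents:
            node = node.setdefault(key, {})  # create intermediate dicts on demand
        node[leaf] = value
    return nested


# Made-up keys, only to show the shape of the result.
example = {"transformer/embedder": "E", "transformer/layer_0/attn": "A0"}
print(nest_params(example))
```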
  2. Load the Gemma tokenizer, built with sentencepiece.SentencePieceProcessor:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  3. To load the correct configuration automatically from the Gemma model checkpoint, use gemma.transformer.TransformerConfig. The cache_size argument is the number of time steps in the Gemma Transformer cache. Then instantiate the Gemma model as transformer with gemma.transformer.Transformer, which inherits from flax.linen.Module.
from gemma import transformer as transformer_lib

transformer_config = transformer_lib.TransformerConfig.from_params(
    params=params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(transformer_config)
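Since cache_size sets how many time steps (token positions) the cache can hold, the cache's memory grows roughly linearly with it. The back-of-the-envelope sketch below assumes a standard key/value cache layout, bfloat16 storage (2 bytes per value), and batch size 1. It uses Gemma 2 2B shapes as we understand them (26 layers, 4 key/value heads, head dimension 256). These are assumptions, so check them against transformer_config, and keep in mind that the library's actual cache layout and dtype may differ. Because the cache holds both the prompt and the generated tokens, we would expect cache_size to need to cover the prompt length plus total_generation_steps.

```python
def kv_cache_bytes(
    num_layers: int,
    cache_size: int,
    num_kv_heads: int,
    head_dim: int,
    bytes_per_value: int = 2,
    batch_size: int = 1,
) -> int:
    """Rough key/value cache size: one K and one V tensor per layer, each of
    shape [batch_size, cache_size, num_kv_heads, head_dim]."""
    values_per_layer = 2 * batch_size * cache_size * num_kv_heads * head_dim  # K and V
    return num_layers * values_per_layer * bytes_per_value


# Assumed Gemma 2 2B shapes; read the real values from transformer_config.
estimate = kv_cache_bytes(num_layers=26, cache_size=1024, num_kv_heads=4, head_dim=256)
print(f"{estimate / 2**20:.0f} MiB")  # 104 MiB
```

Doubling cache_size doubles this estimate, which is one reason to keep it only as large as your prompts and generation length require.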
  4. Create a sampler with gemma.sampler.Sampler on top of the Gemma model checkpoint/weights and the tokenizer:
from gemma import sampler as sampler_lib

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer'],
)
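Conceptually, a sampler repeatedly asks the model for next-token scores, picks a token, appends it, and repeats until it reaches an end marker or runs out of total_generation_steps. The toy sketch below shows that loop with greedy decoding (always pick the highest score), one of the simplest selection strategies. It is not gemma.sampler.Sampler's implementation, which works on batched JAX arrays, uses the cache, and may choose tokens differently. The bigram "model" and every name in it are hypothetical.

```python
from typing import Callable, List, Sequence

# A "model" here is any function mapping the tokens so far to next-token scores.
NextTokenScores = Callable[[Sequence[int]], Sequence[float]]


def greedy_decode(
    model: NextTokenScores,
    prompt_ids: Sequence[int],
    total_generation_steps: int,
    end_id: int,
) -> List[int]:
    """Generate up to total_generation_steps tokens, stopping early at end_id."""
    tokens = list(prompt_ids)
    generated: List[int] = []
    for _ in range(total_generation_steps):
        scores = model(tokens)
        # Greedy choice: take the highest-scoring token id.
        next_id = max(range(len(scores)), key=lambda i: scores[i])
        generated.append(next_id)
        tokens.append(next_id)  # feed the new token back in (autoregression)
        if next_id == end_id:
            break
    return generated


# Toy bigram "model" over a 5-token vocabulary: the next token depends only on
# the last one (0 -> 2 -> 3 -> 4, where 4 plays the role of an end marker).
BIGRAM_SCORES = {
    0: [0.0, 0.1, 0.9, 0.0, 0.0],
    2: [0.0, 0.0, 0.0, 0.8, 0.2],
    3: [0.1, 0.0, 0.0, 0.0, 0.9],
}


def toy_model(tokens: Sequence[int]) -> Sequence[float]:
    return BIGRAM_SCORES.get(tokens[-1], [0.0, 0.0, 0.0, 0.0, 1.0])


print(greedy_decode(toy_model, [0], total_generation_steps=10, end_id=4))  # [2, 3, 4]
print(greedy_decode(toy_model, [0], total_generation_steps=2, end_id=4))   # [2, 3]
```

The second call shows how a small step budget cuts the response off before the end marker, which is what a too-small total_generation_steps does to real output.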
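  5. Write your prompts as a list of strings (prompt below, passed to the sampler as input_strings) and perform inference. You can tweak total_generation_steps, the number of steps performed when generating a response; this example uses 128 to preserve host memory.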
prompt = [
    "what is JAX in 3 bullet points?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=128,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
Prompt:
what is JAX in 3 bullet points?
Output:


* **High-performance numerical computation:** JAX leverages the power of GPUs and TPUs to accelerate complex mathematical operations, making it ideal for scientific computing, machine learning, and data analysis.
* **Automatic differentiation:** JAX provides automatic differentiation capabilities, allowing you to compute gradients and optimize models efficiently. This simplifies the process of training deep learning models.
* **Functional programming:** JAX embraces functional programming principles, promoting code readability and maintainability. It offers a flexible and expressive syntax for defining and manipulating data. 


<end_of_turn>
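The output ends with <end_of_turn>, the marker that instruction-tuned Gemma models use to close a conversational turn. Gemma's published prompt format wraps each user message in <start_of_turn>user ... <end_of_turn> and then opens a model turn, whereas the prompt above is passed without this wrapping. The sketch below, with hypothetical helper names, shows that formatting plus a small post-processing step that drops everything from the first <end_of_turn> onward. Whether the formatting changes results for your prompts is something to try rather than assume.

```python
START_OF_TURN = "<start_of_turn>"
END_OF_TURN = "<end_of_turn>"


def format_user_turn(message: str) -> str:
    """Wrap a single user message in Gemma's turn markers and open a model turn."""
    return f"{START_OF_TURN}user\n{message}{END_OF_TURN}\n{START_OF_TURN}model\n"


def strip_end_of_turn(text: str) -> str:
    """Keep only the text before the first <end_of_turn>, trimmed of whitespace."""
    return text.split(END_OF_TURN, 1)[0].strip()


print(format_user_turn("what is JAX in 3 bullet points?"))
print(strip_end_of_turn("\n\n* point one\n* point two\n\n<end_of_turn>"))
```

With the objects from this notebook, you might call sampler(input_strings=[format_user_turn(p) for p in prompt], total_generation_steps=128) and apply strip_end_of_turn to each entry of reply.text.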
  6. (Optional) If you have finished the notebook and want to try another prompt, run this cell to free up memory. Afterwards, you can create the sampler again with gemma.sampler.Sampler, then customize and run the prompt.
del sampler
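del removes only a name. Python frees an object (and, we expect, any accelerator memory held by its arrays) only once nothing else references it. After del sampler, the params and transformer variables still refer to the model weights, so freeing most of the memory may also require deleting those names. The pure-Python sketch below illustrates this reference rule with weakref, using a stand-in object instead of real weights.

```python
import gc
import weakref
from typing import Tuple


class FakeWeights:
    """Stand-in for a large object such as model parameters."""


def reference_demo() -> Tuple[bool, bool]:
    weights = FakeWeights()
    sampler_like = {"params": weights}  # plays the role of the sampler
    probe = weakref.ref(weights)        # observes the object without keeping it alive

    del sampler_like
    gc.collect()
    still_alive = probe() is not None   # `weights` still refers to the object

    del weights
    gc.collect()
    freed = probe() is None             # no references left, so it has been freed
    return still_alive, freed


print(reference_demo())  # (True, True)
```

In the notebook, the equivalent would be something like del sampler, transformer, params followed by gc.collect(). You would then need to rerun the loading steps before sampling again.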

Learn more