Inferência com o Gemma usando JAX e Flax

Ver em ai.google.dev Executar no Google Colab Abrir na Vertex AI Veja o código-fonte no GitHub

Visão geral

O Gemma é uma família de modelos de linguagem grandes, leves e de última geração abertos, baseados na pesquisa e tecnologia do Google DeepMind Gemini. Este tutorial demonstra como realizar uma amostragem/inferência básica com o modelo de instrução Gemma 2B usando a biblioteca gemma do Google DeepMind, escrita com JAX (uma biblioteca de computação numérica de alto desempenho), Flax (a biblioteca de rede neural baseada em JAX), Orbax (uma biblioteca baseada em JAX para utilitários de treinamento, como tokenizador) e SentencePiece Embora o Flax não seja usado diretamente neste notebook, ele foi usado para criar o Gemma.

Este notebook pode ser executado no Google Colab com GPU T4 sem custo financeiro. Acesse Editar > Configurações do notebook > Em Acelerador de hardware, selecione GPU T4.

Configuração

1. Configure o acesso do Kaggle para o Gemma

Para concluir este tutorial, primeiro siga as instruções de configuração em Configuração do Gemma, que mostram como fazer o seguinte:

  • Acesse o Gemma em kaggle.com.
  • Selecione um ambiente de execução do Colab com recursos suficientes para executar o modelo Gemma.
  • Gere e configure um nome de usuário e uma chave de API do Kaggle.

Depois de concluir a configuração do Gemma, vá para a próxima seção, em que você definirá variáveis de ambiente para o ambiente do Colab.

2. Defina as variáveis de ambiente

Defina as variáveis de ambiente para KAGGLE_USERNAME e KAGGLE_KEY. Quando a pergunta "Conceder acesso?" for exibida concordam em fornecer acesso a secrets.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. Instalar a biblioteca gemma

Este notebook se concentra no uso de uma GPU Colab sem custo financeiro. Para ativar a aceleração de hardware, clique em Editar > Configurações do notebook > Selecione GPU T4 > Clique em Salvar.

Em seguida, você precisa instalar a biblioteca gemma do Google DeepMind no github.com/google-deepmind/gemma. Se você receber um erro sobre "resolvedor de dependências do pip", geralmente ele poderá ser ignorado.

pip install -q git+https://github.com/google-deepmind/gemma.git

Carregar e preparar o modelo do Gemma

  1. Carregue o modelo do Gemma com kagglehub.model_download, que usa três argumentos:
  • handle: o identificador do modelo do Kaggle
  • path: (string opcional) o caminho local
  • force_download (booleano opcional): força o download do modelo novamente.
.
GEMMA_VARIANT = 'gemma2-2b-it' # @param ['gemma2-2b', 'gemma2-2b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/gemma-2/flax/{GEMMA_VARIANT}')
Downloading 11 files:   0%|          | 0/11 [00:00<?, ?it/s]
Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2-2b/flax/gemma2-2b-it/1/download/gemma2-2b-it/ocdbt.process_0/manifest.ocdbt...
100%|██████████| 180/180 [00:00<00:00, 101kB/s]
Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2-2b/flax/gemma2-2b-it/1/download/gemma2-2b-it/d/b5a4695f4be0a2f41ec1e25616ebd7e7...
100%|██████████| 2.66k/2.66k [00:00<00:00, 5.36MB/s]
Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2-2b/flax/gemma2-2b-it/1/download/gemma2-2b-it/descriptor/descriptor.pbtxt...
100%|██████████| 45.0/45.0 [00:00<00:00, 90.0kB/s]
Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2-2b/flax/gemma2-2b-it/1/download/gemma2-2b-it/_METADATA...
100%|██████████| 55.3k/55.3k [00:00<00:00, 29.5MB/s]
Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2-2b/flax/gemma2-2b-it/1/download/gemma2-2b-it/_CHECKPOINT_METADATA...
100%|██████████| 92.0/92.0 [00:00<00:00, 234kB/s]
Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2-2b/flax/gemma2-2b-it/1/download/gemma2-2b-it/ocdbt.process_0/d/bf69258061ae5f35eb7a5669fe6877d4...
0%|          | 0.00/2.12G [00:00<?, ?B/s]
Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2-2b/flax/gemma2-2b-it/1/download/gemma2-2b-it/ocdbt.process_0/d/fc20151969d7ca91ea9d8275bda0e219...
100%|██████████| 2.64k/2.64k [00:00<00:00, 5.58MB/s]

  0%|          | 2.00M/2.12G [00:00<01:48, 20.8MB/s]
Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2-2b/flax/gemma2-2b-it/1/download/gemma2-2b-it/ocdbt.process_0/d/834bb4bf1e3854eb09f6208c95c071b2...
0%|          | 0.00/1.70G [00:00<?, ?B/s]
  0%|          | 9.00M/2.12G [00:00<00:46, 48.2MB/s]

  0%|          | 3.00M/1.70G [00:00<01:06, 27.6MB/s]
  1%|          | 14.0M/2.12G [00:00<00:46, 48.6MB/s]

  1%|          | 9.00M/1.70G [00:00<00:40, 44.5MB/s]
Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2-2b/flax/gemma2-2b-it/1/download/gemma2-2b-it/manifest.ocdbt...
100%|██████████| 118/118 [00:00<00:00, 303kB/s]

  1%|          | 21.0M/2.12G [00:00<00:41, 53.7MB/s]
Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2-2b/flax/gemma2-2b-it/1/download/gemma2-2b-it/checkpoint...
0%|          | 0.00/22.5k [00:00<?, ?B/s]
Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2-2b/flax/gemma2-2b-it/1/download/tokenizer.model...
100%|██████████| 22.5k/22.5k [00:00<00:00, 24.7MB/s]


  1%|          | 17.0M/1.70G [00:00<00:36, 49.5MB/s]


  0%|          | 0.00/4.04M [00:00<?, ?B/s]
100%|██████████| 4.04M/4.04M [00:00<00:00, 64.6MB/s]


  1%|▏         | 24.0M/1.70G [00:00<00:34, 52.7MB/s]
  2%|▏         | 40.0M/2.12G [00:00<00:34, 64.6MB/s]

  2%|▏         | 33.0M/1.70G [00:00<00:27, 64.4MB/s]
  2%|▏         | 49.0M/2.12G [00:00<00:34, 64.8MB/s]

  3%|▎         | 47.0M/1.70G [00:00<00:20, 86.9MB/s]
  3%|▎         | 59.0M/2.12G [00:00<00:29, 74.4MB/s]

  3%|▎         | 56.0M/1.70G [00:00<00:24, 73.1MB/s]
  3%|▎         | 67.0M/2.12G [00:01<00:31, 70.1MB/s]

  4%|▎         | 64.0M/1.70G [00:01<00:25, 69.4MB/s]
  3%|▎         | 74.0M/2.12G [00:01<00:32, 67.4MB/s]

  4%|▍         | 73.0M/1.70G [00:01<00:23, 75.7MB/s]
  4%|▍         | 84.0M/2.12G [00:01<00:28, 75.5MB/s]

  5%|▍         | 81.0M/1.70G [00:01<00:22, 77.7MB/s]

  5%|▌         | 95.0M/1.70G [00:01<00:17, 96.5MB/s]
  4%|▍         | 92.0M/2.12G [00:01<00:38, 56.7MB/s]

  6%|▌         | 106M/1.70G [00:01<00:17, 101MB/s]  
  5%|▍         | 102M/2.12G [00:01<00:32, 67.2MB/s] 

  7%|▋         | 117M/1.70G [00:01<00:16, 102MB/s]
  5%|▌         | 110M/2.12G [00:01<00:30, 70.5MB/s]

  7%|▋         | 128M/1.70G [00:01<00:16, 105MB/s]
  5%|▌         | 119M/2.12G [00:01<00:28, 75.6MB/s]

  8%|▊         | 142M/1.70G [00:01<00:14, 117MB/s]
  6%|▌         | 129M/2.12G [00:02<00:30, 70.1MB/s]

  9%|▉         | 154M/1.70G [00:01<00:17, 92.5MB/s]
  6%|▋         | 138M/2.12G [00:02<00:28, 73.9MB/s]

  9%|▉         | 164M/1.70G [00:02<00:18, 87.7MB/s]
  7%|▋         | 146M/2.12G [00:02<00:30, 70.2MB/s]

 10%|▉         | 173M/1.70G [00:02<00:20, 81.7MB/s]
  7%|▋         | 153M/2.12G [00:02<00:33, 63.0MB/s]

 10%|█         | 182M/1.70G [00:02<00:19, 82.8MB/s]
  8%|▊         | 164M/2.12G [00:02<00:27, 75.3MB/s]

 11%|█         | 195M/1.70G [00:02<00:17, 90.8MB/s]
  8%|▊         | 174M/2.12G [00:02<00:25, 82.0MB/s]

 12%|█▏        | 207M/1.70G [00:02<00:16, 99.0MB/s]
  9%|▊         | 186M/2.12G [00:02<00:22, 92.9MB/s]

 13%|█▎        | 218M/1.70G [00:02<00:16, 99.7MB/s]
  9%|▉         | 196M/2.12G [00:02<00:22, 92.4MB/s]

 13%|█▎        | 229M/1.70G [00:02<00:15, 103MB/s] 
 10%|▉         | 206M/2.12G [00:02<00:22, 92.1MB/s]

 14%|█▎        | 239M/1.70G [00:02<00:15, 99.4MB/s]
 10%|▉         | 215M/2.12G [00:03<00:22, 91.3MB/s]

 14%|█▍        | 250M/1.70G [00:03<00:15, 101MB/s] 
 10%|█         | 226M/2.12G [00:03<00:21, 96.5MB/s]

 15%|█▌        | 263M/1.70G [00:03<00:14, 108MB/s]
 11%|█         | 238M/2.12G [00:03<00:19, 105MB/s] 
 11%|█▏        | 249M/2.12G [00:03<00:19, 103MB/s]

 16%|█▌        | 274M/1.70G [00:03<00:16, 91.6MB/s]
 12%|█▏        | 259M/2.12G [00:03<00:21, 93.3MB/s]

 16%|█▋        | 284M/1.70G [00:03<00:20, 76.0MB/s]
 12%|█▏        | 269M/2.12G [00:03<00:21, 94.3MB/s]

 17%|█▋        | 295M/1.70G [00:03<00:18, 84.0MB/s]
 13%|█▎        | 279M/2.12G [00:03<00:20, 94.2MB/s]

 17%|█▋        | 304M/1.70G [00:03<00:17, 84.4MB/s]
 13%|█▎        | 289M/2.12G [00:03<00:20, 94.9MB/s]

 18%|█▊        | 313M/1.70G [00:03<00:18, 81.9MB/s]
 14%|█▍        | 299M/2.12G [00:03<00:21, 91.4MB/s]
 14%|█▍        | 308M/2.12G [00:04<00:21, 89.4MB/s]

 18%|█▊        | 322M/1.70G [00:03<00:20, 73.5MB/s]

 19%|█▉        | 330M/1.70G [00:04<00:19, 74.9MB/s]
 15%|█▍        | 317M/2.12G [00:04<00:23, 81.1MB/s]
 15%|█▌        | 326M/2.12G [00:04<00:23, 83.6MB/s]

 19%|█▉        | 338M/1.70G [00:04<00:20, 72.0MB/s]

 20%|█▉        | 346M/1.70G [00:04<00:19, 74.6MB/s]
 15%|█▌        | 335M/2.12G [00:04<00:24, 79.2MB/s]

 20%|██        | 354M/1.70G [00:04<00:19, 75.0MB/s]
 16%|█▌        | 344M/2.12G [00:04<00:23, 81.4MB/s]
 16%|█▋        | 352M/2.12G [00:04<00:28, 67.3MB/s]

 21%|██        | 362M/1.70G [00:04<00:26, 54.3MB/s]
 17%|█▋        | 359M/2.12G [00:04<00:31, 59.6MB/s]

 21%|██        | 369M/1.70G [00:04<00:26, 53.4MB/s]
 17%|█▋        | 366M/2.12G [00:05<00:31, 59.0MB/s]

 22%|██▏       | 375M/1.70G [00:04<00:26, 54.9MB/s]
 17%|█▋        | 372M/2.12G [00:05<00:31, 59.2MB/s]

 22%|██▏       | 381M/1.70G [00:05<00:25, 56.2MB/s]
 17%|█▋        | 379M/2.12G [00:05<00:30, 62.3MB/s]

 22%|██▏       | 388M/1.70G [00:05<00:24, 56.8MB/s]
 18%|█▊        | 386M/2.12G [00:05<00:29, 63.8MB/s]

 23%|██▎       | 395M/1.70G [00:05<00:23, 60.2MB/s]
 18%|█▊        | 394M/2.12G [00:05<00:27, 68.5MB/s]

 23%|██▎       | 402M/1.70G [00:05<00:22, 62.7MB/s]
 19%|█▊        | 401M/2.12G [00:05<00:27, 66.3MB/s]

 23%|██▎       | 409M/1.70G [00:05<00:21, 64.4MB/s]
 19%|█▉        | 408M/2.12G [00:05<00:28, 65.4MB/s]

 24%|██▍       | 416M/1.70G [00:05<00:21, 65.4MB/s]

 24%|██▍       | 423M/1.70G [00:05<00:26, 51.4MB/s]
 19%|█▉        | 415M/2.12G [00:07<03:02, 10.1MB/s]

 25%|██▍       | 429M/1.70G [00:08<02:56, 7.79MB/s]
 19%|█▉        | 420M/2.12G [00:08<03:17, 9.28MB/s]

 25%|██▌       | 439M/1.70G [00:08<01:52, 12.2MB/s]
 20%|█▉        | 432M/2.12G [00:08<01:56, 15.7MB/s]

 26%|██▌       | 447M/1.70G [00:08<01:22, 16.5MB/s]
 20%|██        | 441M/2.12G [00:08<01:25, 21.1MB/s]

 26%|██▌       | 454M/1.70G [00:08<01:05, 20.5MB/s]
 21%|██        | 448M/2.12G [00:09<01:14, 24.0MB/s]

 26%|██▋       | 460M/1.70G [00:08<00:54, 24.4MB/s]

 27%|██▋       | 468M/1.70G [00:09<00:42, 31.6MB/s]
 21%|██        | 454M/2.12G [00:09<01:07, 26.4MB/s]
 21%|██▏       | 464M/2.12G [00:09<00:49, 36.4MB/s]

 27%|██▋       | 476M/1.70G [00:09<00:38, 34.4MB/s]

 28%|██▊       | 487M/1.70G [00:09<00:28, 47.0MB/s]
 22%|██▏       | 471M/2.12G [00:09<00:53, 33.2MB/s]

 28%|██▊       | 495M/1.70G [00:09<00:28, 46.0MB/s]
 22%|██▏       | 477M/2.12G [00:09<00:49, 36.0MB/s]

 29%|██▉       | 502M/1.70G [00:09<00:27, 47.6MB/s]

 29%|██▉       | 510M/1.70G [00:09<00:23, 54.1MB/s]
 22%|██▏       | 483M/2.12G [00:09<00:52, 33.9MB/s]

 30%|██▉       | 519M/1.70G [00:09<00:20, 62.0MB/s]
 23%|██▎       | 491M/2.12G [00:10<00:41, 41.9MB/s]

 30%|███       | 527M/1.70G [00:09<00:19, 65.6MB/s]
 23%|██▎       | 497M/2.12G [00:10<00:47, 37.1MB/s]
 23%|██▎       | 506M/2.12G [00:10<00:36, 47.1MB/s]

 31%|███       | 535M/1.70G [00:10<00:26, 47.1MB/s]
 24%|██▎       | 513M/2.12G [00:10<00:35, 49.5MB/s]

 31%|███       | 541M/1.70G [00:10<00:25, 49.2MB/s]
 24%|██▍       | 523M/2.12G [00:10<00:28, 60.8MB/s]

 32%|███▏      | 551M/1.70G [00:10<00:20, 60.3MB/s]
 24%|██▍       | 530M/2.12G [00:10<00:30, 56.8MB/s]

 32%|███▏      | 561M/1.70G [00:10<00:18, 65.7MB/s]
 25%|██▍       | 537M/2.12G [00:10<00:29, 58.3MB/s]

 33%|███▎      | 569M/1.70G [00:10<00:18, 67.2MB/s]
 25%|██▌       | 547M/2.12G [00:10<00:24, 68.2MB/s]

 33%|███▎      | 578M/1.70G [00:10<00:16, 73.2MB/s]
 26%|██▌       | 557M/2.12G [00:11<00:21, 77.0MB/s]

 34%|███▎      | 586M/1.70G [00:10<00:17, 71.0MB/s]

 34%|███▍      | 595M/1.70G [00:11<00:15, 76.1MB/s]
 26%|██▌       | 565M/2.12G [00:11<00:24, 69.2MB/s]

 35%|███▍      | 609M/1.70G [00:11<00:13, 88.9MB/s]
 26%|██▋       | 573M/2.12G [00:11<00:26, 63.3MB/s]

 35%|███▌      | 618M/1.70G [00:11<00:13, 90.0MB/s]
 27%|██▋       | 583M/2.12G [00:11<00:23, 71.8MB/s]

 36%|███▌      | 630M/1.70G [00:11<00:11, 99.5MB/s]

 37%|███▋      | 640M/1.70G [00:12<00:37, 31.0MB/s]
 27%|██▋       | 591M/2.12G [00:12<01:09, 23.8MB/s]

 37%|███▋      | 650M/1.70G [00:12<00:29, 38.7MB/s]
 28%|██▊       | 602M/2.12G [00:12<00:49, 32.9MB/s]
 28%|██▊       | 611M/2.12G [00:12<00:40, 40.3MB/s]

 38%|███▊      | 660M/1.70G [00:12<00:24, 46.6MB/s]

 39%|███▊      | 673M/1.70G [00:12<00:18, 60.3MB/s]
 29%|██▊       | 619M/2.12G [00:12<00:37, 43.2MB/s]

 39%|███▉      | 684M/1.70G [00:12<00:15, 69.5MB/s]
 29%|██▉       | 626M/2.12G [00:12<00:35, 45.0MB/s]

 40%|████      | 697M/1.70G [00:12<00:13, 79.8MB/s]
 29%|██▉       | 638M/2.12G [00:13<00:27, 59.2MB/s]

 41%|████      | 707M/1.70G [00:12<00:12, 83.7MB/s]
 30%|██▉       | 646M/2.12G [00:13<00:25, 63.4MB/s]

 41%|████      | 717M/1.70G [00:13<00:12, 88.5MB/s]
 30%|███       | 654M/2.12G [00:13<00:23, 67.5MB/s]
 31%|███       | 662M/2.12G [00:13<00:22, 70.8MB/s]

 42%|████▏     | 727M/1.70G [00:13<00:15, 67.7MB/s]

 42%|████▏     | 736M/1.70G [00:13<00:15, 68.8MB/s]
 31%|███       | 670M/2.12G [00:13<00:26, 58.5MB/s]

 43%|████▎     | 744M/1.70G [00:13<00:15, 67.3MB/s]
 31%|███▏      | 677M/2.12G [00:13<00:29, 53.2MB/s]

 43%|████▎     | 755M/1.70G [00:13<00:13, 77.3MB/s]
 32%|███▏      | 683M/2.12G [00:13<00:28, 54.4MB/s]

 44%|████▍     | 765M/1.70G [00:13<00:12, 83.9MB/s]
 32%|███▏      | 690M/2.12G [00:13<00:26, 58.6MB/s]

 44%|████▍     | 774M/1.70G [00:13<00:13, 77.1MB/s]
 32%|███▏      | 703M/2.12G [00:14<00:19, 77.2MB/s]

 45%|████▌     | 786M/1.70G [00:14<00:11, 88.4MB/s]
 33%|███▎      | 712M/2.12G [00:14<00:20, 73.5MB/s]

 46%|████▌     | 797M/1.70G [00:14<00:10, 94.8MB/s]
 33%|███▎      | 722M/2.12G [00:14<00:19, 79.4MB/s]

 46%|████▋     | 807M/1.70G [00:14<00:10, 93.8MB/s]
 34%|███▍      | 731M/2.12G [00:14<00:18, 81.6MB/s]

 47%|████▋     | 817M/1.70G [00:14<00:10, 90.0MB/s]
 34%|███▍      | 740M/2.12G [00:14<00:17, 83.8MB/s]

 48%|████▊     | 829M/1.70G [00:14<00:09, 98.2MB/s]
 35%|███▍      | 749M/2.12G [00:14<00:17, 84.2MB/s]

 48%|████▊     | 839M/1.70G [00:14<00:09, 95.3MB/s]
 35%|███▌      | 759M/2.12G [00:14<00:16, 89.8MB/s]
 36%|███▌      | 769M/2.12G [00:14<00:15, 93.3MB/s]

 49%|████▉     | 849M/1.70G [00:14<00:09, 93.8MB/s]
 36%|███▌      | 780M/2.12G [00:14<00:14, 97.7MB/s]

 49%|████▉     | 859M/1.70G [00:14<00:10, 91.0MB/s]
 37%|███▋      | 793M/2.12G [00:15<00:13, 106MB/s] 

 50%|████▉     | 868M/1.70G [00:14<00:10, 89.9MB/s]
 37%|███▋      | 804M/2.12G [00:15<00:13, 107MB/s]

 50%|█████     | 877M/1.70G [00:15<00:10, 87.0MB/s]

 51%|█████     | 886M/1.70G [00:15<00:10, 85.0MB/s]
 38%|███▊      | 815M/2.12G [00:15<00:16, 84.8MB/s]

 51%|█████▏    | 895M/1.70G [00:15<00:12, 69.5MB/s]
 38%|███▊      | 824M/2.12G [00:15<00:18, 74.1MB/s]

 52%|█████▏    | 904M/1.70G [00:15<00:11, 73.7MB/s]
 38%|███▊      | 832M/2.12G [00:15<00:18, 75.1MB/s]

 52%|█████▏    | 912M/1.70G [00:15<00:11, 75.6MB/s]
 39%|███▉      | 843M/2.12G [00:15<00:16, 83.7MB/s]
 40%|███▉      | 856M/2.12G [00:15<00:14, 95.7MB/s]

 53%|█████▎    | 921M/1.70G [00:15<00:12, 71.1MB/s]
 40%|███▉      | 866M/2.12G [00:15<00:13, 97.5MB/s]

 53%|█████▎    | 931M/1.70G [00:15<00:10, 77.9MB/s]
 41%|████      | 878M/2.12G [00:16<00:12, 104MB/s] 

 54%|█████▍    | 939M/1.70G [00:16<00:11, 70.3MB/s]
 41%|████      | 889M/2.12G [00:16<00:12, 104MB/s]

 55%|█████▍    | 950M/1.70G [00:16<00:10, 80.9MB/s]

 56%|█████▌    | 967M/1.70G [00:16<00:07, 105MB/s] 
 42%|████▏     | 900M/2.12G [00:16<00:17, 73.9MB/s]

 56%|█████▌    | 978M/1.70G [00:16<00:07, 105MB/s]
 42%|████▏     | 909M/2.12G [00:16<00:17, 76.9MB/s]

 57%|█████▋    | 989M/1.70G [00:16<00:07, 103MB/s]
 43%|████▎     | 921M/2.12G [00:16<00:14, 87.9MB/s]

 57%|█████▋    | 0.98G/1.70G [00:16<00:07, 105MB/s]

 58%|█████▊    | 0.99G/1.70G [00:16<00:06, 110MB/s]
 43%|████▎     | 931M/2.12G [00:16<00:15, 81.4MB/s]
 43%|████▎     | 940M/2.12G [00:16<00:15, 82.6MB/s]

 59%|█████▊    | 1.00G/1.70G [00:16<00:09, 76.3MB/s]
 44%|████▍     | 949M/2.12G [00:17<00:18, 70.7MB/s]
 44%|████▍     | 957M/2.12G [00:17<00:18, 67.3MB/s]

 59%|█████▉    | 1.01G/1.70G [00:17<00:11, 64.5MB/s]
 45%|████▍     | 964M/2.12G [00:17<00:19, 65.9MB/s]

 60%|█████▉    | 1.02G/1.70G [00:17<00:11, 64.5MB/s]
 45%|████▍     | 971M/2.12G [00:17<00:19, 65.2MB/s]

 60%|██████    | 1.02G/1.70G [00:17<00:11, 63.6MB/s]
 45%|████▌     | 978M/2.12G [00:17<00:19, 63.6MB/s]

 61%|██████    | 1.03G/1.70G [00:17<00:11, 62.9MB/s]
 45%|████▌     | 985M/2.12G [00:17<00:19, 64.7MB/s]

 61%|██████    | 1.04G/1.70G [00:17<00:11, 62.8MB/s]
 46%|████▌     | 992M/2.12G [00:17<00:18, 65.3MB/s]

 61%|██████▏   | 1.04G/1.70G [00:17<00:10, 64.5MB/s]
 46%|████▌     | 0.98G/2.12G [00:17<00:17, 68.9MB/s]

 62%|██████▏   | 1.05G/1.70G [00:17<00:10, 66.5MB/s]
 46%|████▋     | 0.98G/2.12G [00:18<00:17, 69.8MB/s]

 62%|██████▏   | 1.06G/1.70G [00:17<00:10, 67.7MB/s]
 47%|████▋     | 0.99G/2.12G [00:18<00:16, 73.2MB/s]

 63%|██████▎   | 1.07G/1.70G [00:18<00:09, 69.0MB/s]
 47%|████▋     | 1.00G/2.12G [00:18<00:16, 72.2MB/s]

 63%|██████▎   | 1.07G/1.70G [00:18<00:09, 68.5MB/s]
 48%|████▊     | 1.01G/2.12G [00:18<00:16, 73.8MB/s]

 63%|██████▎   | 1.08G/1.70G [00:18<00:09, 69.1MB/s]
 48%|████▊     | 1.01G/2.12G [00:18<00:18, 63.5MB/s]

 64%|██████▍   | 1.09G/1.70G [00:18<00:10, 61.4MB/s]
 48%|████▊     | 1.02G/2.12G [00:18<00:20, 57.1MB/s]

 64%|██████▍   | 1.09G/1.70G [00:18<00:11, 55.9MB/s]
 49%|████▊     | 1.03G/2.12G [00:18<00:20, 57.9MB/s]

 65%|██████▍   | 1.10G/1.70G [00:18<00:11, 56.5MB/s]
 49%|████▉     | 1.04G/2.12G [00:18<00:18, 63.5MB/s]

 65%|██████▍   | 1.10G/1.70G [00:18<00:11, 57.4MB/s]
 49%|████▉     | 1.04G/2.12G [00:19<00:18, 63.4MB/s]

 65%|██████▌   | 1.11G/1.70G [00:18<00:10, 58.5MB/s]
 50%|████▉     | 1.05G/2.12G [00:19<00:17, 64.7MB/s]

 66%|██████▌   | 1.12G/1.70G [00:19<00:10, 61.8MB/s]
 50%|████▉     | 1.06G/2.12G [00:19<00:16, 67.6MB/s]

 66%|██████▌   | 1.12G/1.70G [00:19<00:09, 63.1MB/s]
 50%|█████     | 1.06G/2.12G [00:19<00:15, 70.6MB/s]

 67%|██████▋   | 1.13G/1.70G [00:19<00:08, 69.1MB/s]
 51%|█████     | 1.07G/2.12G [00:19<00:15, 72.6MB/s]

 67%|██████▋   | 1.14G/1.70G [00:19<00:08, 71.4MB/s]
 51%|█████     | 1.08G/2.12G [00:19<00:15, 73.6MB/s]

 68%|██████▊   | 1.15G/1.70G [00:19<00:07, 74.3MB/s]
 51%|█████▏    | 1.09G/2.12G [00:19<00:14, 76.0MB/s]

 68%|██████▊   | 1.16G/1.70G [00:19<00:07, 74.0MB/s]
 52%|█████▏    | 1.10G/2.12G [00:19<00:15, 69.1MB/s]

 69%|██████▊   | 1.17G/1.70G [00:19<00:08, 69.9MB/s]
 52%|█████▏    | 1.10G/2.12G [00:19<00:15, 70.8MB/s]

 69%|██████▉   | 1.17G/1.70G [00:19<00:08, 70.0MB/s]
 52%|█████▏    | 1.11G/2.12G [00:20<00:16, 66.4MB/s]

 69%|██████▉   | 1.18G/1.70G [00:19<00:09, 60.2MB/s]

 70%|██████▉   | 1.19G/1.70G [00:20<00:08, 62.5MB/s]
 53%|█████▎    | 1.12G/2.12G [00:20<00:19, 54.8MB/s]

 70%|███████   | 1.19G/1.70G [00:20<00:08, 63.8MB/s]
 53%|█████▎    | 1.12G/2.12G [00:20<00:18, 57.7MB/s]

 71%|███████   | 1.20G/1.70G [00:20<00:08, 64.8MB/s]
 53%|█████▎    | 1.13G/2.12G [00:20<00:17, 61.3MB/s]

 71%|███████   | 1.21G/1.70G [00:20<00:07, 68.3MB/s]
 54%|█████▍    | 1.14G/2.12G [00:20<00:16, 65.0MB/s]

 71%|███████▏  | 1.21G/1.70G [00:20<00:07, 66.1MB/s]
 54%|█████▍    | 1.15G/2.12G [00:20<00:16, 64.4MB/s]

 72%|███████▏  | 1.22G/1.70G [00:20<00:07, 66.4MB/s]
 54%|█████▍    | 1.15G/2.12G [00:20<00:16, 63.0MB/s]

 72%|███████▏  | 1.23G/1.70G [00:20<00:08, 61.2MB/s]
 55%|█████▍    | 1.16G/2.12G [00:20<00:16, 61.6MB/s]

 73%|███████▎  | 1.23G/1.70G [00:20<00:08, 58.0MB/s]
 55%|█████▌    | 1.17G/2.12G [00:21<00:16, 61.0MB/s]
 55%|█████▌    | 1.17G/2.12G [00:22<01:11, 14.1MB/s]

 73%|███████▎  | 1.24G/1.70G [00:22<00:37, 13.4MB/s]
 56%|█████▌    | 1.18G/2.12G [00:22<00:51, 19.6MB/s]

 73%|███████▎  | 1.25G/1.70G [00:22<00:23, 20.3MB/s]
 56%|█████▌    | 1.19G/2.12G [00:22<00:42, 23.4MB/s]

 74%|███████▍  | 1.25G/1.70G [00:22<00:19, 24.2MB/s]

 74%|███████▍  | 1.26G/1.70G [00:22<00:13, 34.0MB/s]

 75%|███████▌  | 1.28G/1.70G [00:22<00:08, 50.6MB/s]

 76%|███████▌  | 1.29G/1.70G [00:22<00:07, 59.8MB/s]
 57%|█████▋    | 1.20G/2.12G [00:23<00:35, 27.6MB/s]
 57%|█████▋    | 1.21G/2.12G [00:23<00:26, 36.3MB/s]

 76%|███████▋  | 1.30G/1.70G [00:23<00:07, 57.6MB/s]

 77%|███████▋  | 1.31G/1.70G [00:23<00:06, 62.4MB/s]
 57%|█████▋    | 1.21G/2.12G [00:23<00:27, 35.1MB/s]
 58%|█████▊    | 1.22G/2.12G [00:23<00:24, 38.9MB/s]

 77%|███████▋  | 1.31G/1.70G [00:23<00:07, 53.2MB/s]
 58%|█████▊    | 1.23G/2.12G [00:23<00:22, 43.0MB/s]

 78%|███████▊  | 1.32G/1.70G [00:23<00:06, 61.9MB/s]

 78%|███████▊  | 1.33G/1.70G [00:23<00:05, 68.7MB/s]
 58%|█████▊    | 1.24G/2.12G [00:23<00:20, 46.8MB/s]

 79%|███████▉  | 1.34G/1.70G [00:23<00:05, 72.2MB/s]
 59%|█████▉    | 1.24G/2.12G [00:23<00:20, 46.0MB/s]

 80%|███████▉  | 1.35G/1.70G [00:23<00:04, 77.2MB/s]

 80%|███████▉  | 1.36G/1.70G [00:23<00:04, 74.6MB/s]
 59%|█████▉    | 1.25G/2.12G [00:24<00:19, 47.6MB/s]

 81%|████████  | 1.37G/1.70G [00:24<00:04, 79.9MB/s]
 60%|█████▉    | 1.26G/2.12G [00:24<00:16, 56.2MB/s]

 81%|████████  | 1.38G/1.70G [00:24<00:03, 88.6MB/s]
 60%|█████▉    | 1.27G/2.12G [00:24<00:18, 48.6MB/s]

 82%|████████▏ | 1.39G/1.70G [00:24<00:03, 87.9MB/s]
 60%|██████    | 1.27G/2.12G [00:24<00:16, 55.1MB/s]

 82%|████████▏ | 1.40G/1.70G [00:24<00:03, 81.4MB/s]
 61%|██████    | 1.28G/2.12G [00:24<00:14, 62.4MB/s]

 83%|████████▎ | 1.41G/1.70G [00:24<00:03, 82.2MB/s]
 61%|██████    | 1.29G/2.12G [00:24<00:14, 62.8MB/s]

 83%|████████▎ | 1.42G/1.70G [00:24<00:03, 85.9MB/s]
 61%|██████▏   | 1.30G/2.12G [00:24<00:13, 64.0MB/s]

 84%|████████▍ | 1.43G/1.70G [00:24<00:02, 99.6MB/s]

 85%|████████▍ | 1.45G/1.70G [00:24<00:02, 109MB/s] 
 62%|██████▏   | 1.31G/2.12G [00:25<00:15, 57.2MB/s]

 86%|████████▌ | 1.46G/1.70G [00:25<00:02, 110MB/s]
 62%|██████▏   | 1.32G/2.12G [00:25<00:12, 69.9MB/s]

 86%|████████▌ | 1.47G/1.70G [00:25<00:02, 110MB/s]
 63%|██████▎   | 1.33G/2.12G [00:25<00:11, 74.1MB/s]

 87%|████████▋ | 1.48G/1.70G [00:25<00:02, 102MB/s]
 63%|██████▎   | 1.34G/2.12G [00:25<00:10, 82.6MB/s]

 87%|████████▋ | 1.49G/1.70G [00:25<00:02, 96.1MB/s]
 64%|██████▎   | 1.34G/2.12G [00:25<00:09, 85.4MB/s]

 88%|████████▊ | 1.50G/1.70G [00:25<00:02, 106MB/s] 
 64%|██████▍   | 1.35G/2.12G [00:25<00:11, 71.7MB/s]

 89%|████████▉ | 1.52G/1.70G [00:25<00:01, 119MB/s]
 65%|██████▍   | 1.37G/2.12G [00:25<00:09, 84.5MB/s]

 90%|████████▉ | 1.53G/1.70G [00:25<00:01, 117MB/s]
 65%|██████▍   | 1.37G/2.12G [00:25<00:11, 69.7MB/s]

 91%|█████████ | 1.54G/1.70G [00:25<00:01, 99.4MB/s]
 65%|██████▌   | 1.38G/2.12G [00:26<00:10, 74.7MB/s]

 91%|█████████ | 1.55G/1.70G [00:25<00:01, 99.8MB/s]
 66%|██████▌   | 1.39G/2.12G [00:26<00:10, 74.7MB/s]

 92%|█████████▏| 1.56G/1.70G [00:26<00:01, 102MB/s] 
 66%|██████▋   | 1.40G/2.12G [00:26<00:09, 83.9MB/s]

 93%|█████████▎| 1.57G/1.70G [00:26<00:01, 108MB/s]
 67%|██████▋   | 1.41G/2.12G [00:26<00:08, 91.4MB/s]

 93%|█████████▎| 1.58G/1.70G [00:26<00:01, 100MB/s]
 67%|██████▋   | 1.42G/2.12G [00:26<00:08, 90.9MB/s]

 94%|█████████▎| 1.59G/1.70G [00:26<00:01, 87.1MB/s]
 68%|██████▊   | 1.43G/2.12G [00:26<00:08, 86.6MB/s]

 94%|█████████▍| 1.60G/1.70G [00:26<00:01, 82.8MB/s]
 68%|██████▊   | 1.44G/2.12G [00:26<00:10, 70.8MB/s]

 95%|█████████▍| 1.61G/1.70G [00:26<00:01, 78.2MB/s]
 68%|██████▊   | 1.45G/2.12G [00:26<00:09, 73.6MB/s]

 95%|█████████▌| 1.62G/1.70G [00:26<00:01, 83.3MB/s]
 69%|██████▉   | 1.46G/2.12G [00:27<00:08, 80.8MB/s]

 96%|█████████▌| 1.63G/1.70G [00:26<00:00, 85.3MB/s]
 69%|██████▉   | 1.47G/2.12G [00:27<00:08, 81.4MB/s]

 96%|█████████▋| 1.64G/1.70G [00:27<00:00, 87.5MB/s]
 70%|██████▉   | 1.48G/2.12G [00:27<00:07, 86.9MB/s]

 97%|█████████▋| 1.65G/1.70G [00:27<00:00, 70.2MB/s]
 70%|███████   | 1.48G/2.12G [00:27<00:10, 67.4MB/s]

 97%|█████████▋| 1.66G/1.70G [00:27<00:00, 75.5MB/s]
 71%|███████   | 1.49G/2.12G [00:27<00:09, 71.1MB/s]

 98%|█████████▊| 1.67G/1.70G [00:27<00:00, 82.3MB/s]
 71%|███████   | 1.50G/2.12G [00:27<00:08, 74.9MB/s]

 99%|█████████▊| 1.68G/1.70G [00:27<00:00, 90.1MB/s]
 71%|███████▏  | 1.51G/2.12G [00:27<00:08, 75.7MB/s]

 99%|█████████▉| 1.69G/1.70G [00:27<00:00, 89.7MB/s]
 72%|███████▏  | 1.52G/2.12G [00:27<00:08, 79.7MB/s]

100%|██████████| 1.70G/1.70G [00:27<00:00, 65.5MB/s]

 72%|███████▏  | 1.53G/2.12G [00:28<00:08, 76.6MB/s]
 73%|███████▎  | 1.54G/2.12G [00:28<00:06, 98.0MB/s]
 74%|███████▎  | 1.56G/2.12G [00:28<00:05, 112MB/s] 
 74%|███████▍  | 1.57G/2.12G [00:28<00:05, 105MB/s]
 75%|███████▍  | 1.58G/2.12G [00:28<00:05, 107MB/s]
 75%|███████▌  | 1.59G/2.12G [00:28<00:06, 89.3MB/s]
 76%|███████▌  | 1.61G/2.12G [00:28<00:05, 101MB/s] 
 77%|███████▋  | 1.62G/2.12G [00:28<00:04, 112MB/s]
 77%|███████▋  | 1.63G/2.12G [00:29<00:05, 92.3MB/s]
 78%|███████▊  | 1.64G/2.12G [00:29<00:06, 78.4MB/s]
 78%|███████▊  | 1.66G/2.12G [00:29<00:05, 95.5MB/s]
 79%|███████▉  | 1.67G/2.12G [00:29<00:04, 101MB/s] 
 80%|███████▉  | 1.69G/2.12G [00:29<00:04, 114MB/s]
 80%|████████  | 1.70G/2.12G [00:29<00:03, 124MB/s]
 81%|████████  | 1.71G/2.12G [00:29<00:03, 122MB/s]
 82%|████████▏ | 1.73G/2.12G [00:30<00:04, 95.5MB/s]
 82%|████████▏ | 1.74G/2.12G [00:30<00:03, 104MB/s] 
 83%|████████▎ | 1.75G/2.12G [00:30<00:03, 103MB/s]
 83%|████████▎ | 1.76G/2.12G [00:30<00:04, 76.8MB/s]
 84%|████████▍ | 1.78G/2.12G [00:30<00:03, 96.8MB/s]
 85%|████████▍ | 1.79G/2.12G [00:30<00:03, 102MB/s] 
 85%|████████▌ | 1.80G/2.12G [00:31<00:04, 69.5MB/s]
 86%|████████▌ | 1.82G/2.12G [00:31<00:03, 92.2MB/s]
 87%|████████▋ | 1.83G/2.12G [00:31<00:03, 91.5MB/s]
 87%|████████▋ | 1.84G/2.12G [00:31<00:02, 98.1MB/s]
 88%|████████▊ | 1.85G/2.12G [00:31<00:03, 85.1MB/s]
 88%|████████▊ | 1.86G/2.12G [00:31<00:03, 88.0MB/s]
 89%|████████▊ | 1.87G/2.12G [00:31<00:02, 88.9MB/s]
 89%|████████▉ | 1.88G/2.12G [00:32<00:02, 93.5MB/s]
 90%|████████▉ | 1.90G/2.12G [00:32<00:02, 106MB/s] 
 90%|█████████ | 1.91G/2.12G [00:32<00:01, 111MB/s]
 91%|█████████ | 1.92G/2.12G [00:32<00:02, 98.7MB/s]
 91%|█████████▏| 1.93G/2.12G [00:32<00:01, 107MB/s] 
 92%|█████████▏| 1.95G/2.12G [00:32<00:01, 104MB/s]
 93%|█████████▎| 1.96G/2.12G [00:32<00:01, 117MB/s]
 93%|█████████▎| 1.97G/2.12G [00:32<00:01, 106MB/s]
 94%|█████████▍| 1.98G/2.12G [00:33<00:01, 92.6MB/s]
 94%|█████████▍| 1.99G/2.12G [00:33<00:01, 86.4MB/s]
 95%|█████████▍| 2.00G/2.12G [00:33<00:01, 68.3MB/s]
 95%|█████████▌| 2.02G/2.12G [00:33<00:01, 84.0MB/s]
 96%|█████████▌| 2.03G/2.12G [00:33<00:01, 91.6MB/s]
 96%|█████████▋| 2.04G/2.12G [00:33<00:00, 96.2MB/s]
 97%|█████████▋| 2.05G/2.12G [00:33<00:00, 108MB/s] 
 98%|█████████▊| 2.06G/2.12G [00:33<00:00, 89.9MB/s]
 98%|█████████▊| 2.08G/2.12G [00:34<00:00, 103MB/s] 
 99%|█████████▉| 2.09G/2.12G [00:34<00:00, 115MB/s]
100%|██████████| 2.12G/2.12G [00:34<00:00, 66.0MB/s]
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma-2-2b/flax/gemma2-2b-it/1
  1. Verifique o local dos pesos do modelo e do tokenizador e, em seguida, defina as variáveis do caminho. O diretório tokenizador estará no diretório principal em que você fez o download do modelo, enquanto os pesos do modelo estarão em um subdiretório. Exemplo:
  • O arquivo tokenizer.model estará em /LOCAL/PATH/TO/gemma/flax/2b-it/2.
  • O checkpoint do modelo estará em /LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it.
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/gemma-2-2b/flax/gemma2-2b-it/1/gemma2-2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/gemma-2-2b/flax/gemma2-2b-it/1/tokenizer.model

Realizar amostragem/inferência

  1. Carregue e formate o checkpoint do modelo Gemma com o método gemma.params.load_and_format_params:
from gemma import params as params_lib

params = params_lib.load_and_format_params(CKPT_PATH)
  1. Carregue o tokenizador do Gemma, construído usando sentencepiece.SentencePieceProcessor:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. Para carregar automaticamente a configuração correta do checkpoint do modelo Gemma, use gemma.transformer.TransformerConfig. O argumento cache_size é o número de etapas de tempo no cache Transformer do Gemma. Em seguida, instancie o modelo Gemma como transformer com gemma.transformer.Transformer (herdado de flax.linen.Module).
.
from gemma import transformer as transformer_lib

transformer_config = transformer_lib.TransformerConfig.from_params(
    params=params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(transformer_config)
  1. Crie um sampler com gemma.sampler.Sampler sobre o checkpoint/peso do modelo Gemma e o tokenizador:
from gemma import sampler as sampler_lib

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer'],
)
  1. Escreva um comando em input_batch e faça inferências. Você pode ajustar o total_generation_steps, que é o número de etapas realizadas ao gerar uma resposta. Este exemplo usa 100 para preservar a memória do host.
.
prompt = [
    "what is JAX in 3 bullet points?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=128,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
Prompt:
what is JAX in 3 bullet points?
Output:


* **High-performance numerical computation:** JAX leverages the power of GPUs and TPUs to accelerate complex mathematical operations, making it ideal for scientific computing, machine learning, and data analysis.
* **Automatic differentiation:** JAX provides automatic differentiation capabilities, allowing you to compute gradients and optimize models efficiently. This simplifies the process of training deep learning models.
* **Functional programming:** JAX embraces functional programming principles, promoting code readability and maintainability. It offers a flexible and expressive syntax for defining and manipulating data. 


<end_of_turn>
  1. (Opcional) Execute esta célula para liberar memória se você já concluiu o notebook e quer tentar outro comando. Em seguida, você pode instanciar o sampler novamente na etapa 3 e personalizar e executar o comando na etapa 4.
del sampler

Saiba mais