Ajustar modelos Gemma no Keras usando a LoRA

Confira no ai.google.dev Executar no Google Colab Abrir na Vertex AI Veja o código-fonte no GitHub

Visão geral

O Gemma é uma família de modelos abertos leves e de última geração criados com base na mesma pesquisa e tecnologia usadas para criar os modelos do Gemini.

Os modelos de linguagem grandes (LLMs), como o Gemma, têm se mostrado eficazes em várias tarefas de PLN. Primeiro, um LLM é pré-treinado em um grande corpus de texto de maneira autossupervisionada. O pré-treinamento ajuda os LLMs a aprenderem conhecimentos de uso geral, como relações estatísticas entre palavras. Um LLM pode ser ajustado com dados específicos do domínio para realizar tarefas posteriores, como a análise de sentimento.

Os LLMs são extremamente grandes (parâmetros na ordem de bilhões). O ajuste fino completo (que atualiza todos os parâmetros no modelo) não é necessário para a maioria das aplicações, porque os conjuntos de dados de ajuste fino típicos são relativamente muito menores do que os conjuntos de dados de pré-treinamento.

A adaptação de baixo escalão (LoRA) é uma técnica de ajuste fino que reduz bastante o número de parâmetros treináveis para tarefas posteriores congelando os pesos do modelo e inserindo um número menor de novos pesos. Isso torna o treinamento com LoRA muito mais rápido e eficiente em termos de memória, além de produzir pesos de modelo menores (algumas centenas de MB), mantendo a qualidade das saídas do modelo.

Este tutorial mostra como usar o KerasNLP para realizar o ajuste fino do LoRA em um modelo Gemma 2B usando o conjunto de dados Databricks Dolly 15k. Este conjunto de dados contém 15.000 pares de comando / resposta gerados por humanos de alta qualidade,especificamente projetados para o ajuste de LLMs.

Configuração

Acessar o Gemma

Para concluir este tutorial, primeiro você precisa seguir as instruções de configuração em Configuração do Gemma. As instruções de configuração do Gemma mostram como fazer o seguinte:

  • Acesse o Gemma em kaggle.com.
  • Selecione um ambiente de execução do Colab com recursos suficientes para executar o modelo Gemma 2B.
  • Gere e configure um nome de usuário e uma chave de API do Kaggle.

Depois de concluir a configuração do Gemma, vá para a próxima seção, em que você definirá variáveis de ambiente para o ambiente do Colab.

Selecione o ambiente de execução

Para concluir este tutorial, você precisa ter um ambiente de execução do Colab com recursos suficientes para executar o modelo do Gemma. Nesse caso, você pode usar uma GPU T4:

  1. No canto superior direito da janela do Colab, selecione ▾ (Opções de conexão adicionais).
  2. Selecione Mudar o tipo de ambiente de execução.
  3. Em Acelerador de hardware, selecione GPU T4.

Configurar a chave de API

Para usar o Gemma, você precisa informar seu nome de usuário e uma chave de API do Kaggle.

Para gerar uma chave de API do Kaggle, acesse a guia Account do seu perfil de usuário no Kaggle e selecione Create New Token. Isso vai acionar o download de um arquivo kaggle.json que contém suas credenciais de API.

No Colab, selecione Secrets (chaves secretas) (🔑) no painel à esquerda e adicione seu nome de usuário e chave da API do Kaggle. Armazene seu nome de usuário com o nome KAGGLE_USERNAME e sua chave de API com o nome KAGGLE_KEY.

Defina as variáveis de ambiente

Defina as variáveis de ambiente para KAGGLE_USERNAME e KAGGLE_KEY.

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Instalar dependências

Instale o Keras, o KerasNLP e outras dependências.

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U "keras>=3"

Selecione um back-end

O Keras é uma API de aprendizado profundo de alto nível e com vários frameworks projetada para ser simples e fácil de usar. Com o Keras 3, é possível executar fluxos de trabalho em um dos três back-ends: TensorFlow, JAX ou PyTorch.

Neste tutorial, configure o back-end para JAX.

os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

Importar pacotes

Importe o Keras e o KerasNLP.

import keras
import keras_nlp

Carregar conjunto de dados

wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
--2024-07-31 01:56:39--  https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
Resolving huggingface.co (huggingface.co)... 18.164.174.23, 18.164.174.17, 18.164.174.55, ...
Connecting to huggingface.co (huggingface.co)|18.164.174.23|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7 [following]
--2024-07-31 01:56:39--  https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7
Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 18.154.206.4, 18.154.206.17, 18.154.206.28, ...
Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|18.154.206.4|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 13085339 (12M) [text/plain]
Saving to: ‘databricks-dolly-15k.jsonl’

databricks-dolly-15 100%[===================>]  12.48M  73.7MB/s    in 0.2s    

2024-07-31 01:56:40 (73.7 MB/s) - ‘databricks-dolly-15k.jsonl’ saved [13085339/13085339]

Pré-processar os dados. Este tutorial usa um subconjunto de 1.000 exemplos de treinamento para executar o notebook mais rápido. Use mais dados de treinamento para fazer ajustes com maior qualidade.

import json
data = []
with open("databricks-dolly-15k.jsonl") as file:
    for line in file:
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        # Format the entire example as a single string.
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        data.append(template.format(**features))

# Only use 1000 training examples, to keep it fast.
data = data[:1000]

Carregar modelo

O KerasNLP fornece implementações de muitas arquiteturas de modelos conhecidas. Neste tutorial, você vai criar um modelo usando GemmaCausalLM, um modelo completo do Gemma para modelagem de linguagem causal. Um modelo de linguagem causal prevê o próximo token com base nos tokens anteriores.

Crie o modelo usando o método from_preset:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
gemma_lm.summary()

O método from_preset instancia o modelo usando uma arquitetura e pesos predefinidos. No código acima, a string "gemma2_2b_en" especifica a arquitetura predefinida, um modelo do Gemma com 2 bilhões de parâmetros.

Inferência antes do ajuste fino

Nesta seção, você vai consultar o modelo com vários comandos para ver como ele responde.

Europe Trip Prompt

Consultar o modelo para receber sugestões sobre o que fazer em uma viagem à Europa.

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
If you have any special needs, you should contact the embassy of the country that you are visiting.
You should contact the embassy of the country that I will be visiting.

What are my responsibilities when I go on a trip?

Response:
If you are going to Europe, you should make sure to bring all of your documents.
If you are going to Europe, make sure that you have all of your documents.

When do you travel abroad?

Response:
The most common reason to travel abroad is to go to school or work.
The most common reason to travel abroad is to work.

How can I get a visa to Europe?

Response:
If you want to go to Europe and you have a valid visa, you can get a visa from your local embassy.
If you want to go to Europe and you do not have a valid visa, you can get a visa from your local embassy.

When should I go to Europe?

Response:
You should go to Europe when the weather is nice.
You should go to Europe when the weather is bad.

How can I make a reservation for a trip?

O modelo responde com dicas genéricas sobre como planejar uma viagem.

ELI5 Photosynthesis Prompt

Peça para o modelo explicar a fotossíntese de forma simples o suficiente para que uma criança de 5 anos entenda.

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Plants need water, air, sunlight, and carbon dioxide. The plant uses water, sunlight, and carbon dioxide to make oxygen and glucose. The process is also known as photosynthesis.

Instruction:
What is the process of photosynthesis in a plant's cells? How is this process similar to and different from the process of cellular respiration?

Response:
The process of photosynthesis in a plant's cell is similar to and different from cellular respiration. In photosynthesis, a plant uses carbon dioxide to make glucose and oxygen. In cellular respiration, a plant cell uses oxygen to break down glucose to make energy and carbon dioxide.

Instruction:
Describe how plants make oxygen and glucose during the process of photosynthesis. Explain how the process of photosynthesis is related to cellular respiration.

Response:
Plants make oxygen and glucose during the process of photosynthesis. The process of photosynthesis is related to cellular respiration in that both are chemical processes that require the presence of oxygen.

Instruction:
How does photosynthesis occur in the cells of a plant? What is the purpose for each part of the cell?

Response:
Photosynthesis occurs in the cells of a plant. The purpose of

A resposta do modelo contém palavras que podem não ser fáceis de entender para uma criança, como clorofila.

Ajuste fino da LoRA

Para obter melhores respostas do modelo, ajuste-o com Adaptação de baixa classificação (LoRA) usando o conjunto de dados Databricks Dolly 15k.

A classificação LoRA determina a dimensionalidade das matrizes treináveis que são adicionadas aos pesos originais do LLM. Ele controla a expressividade e a precisão dos ajustes de ajuste fino.

Uma classificação mais alta significa que mudanças mais detalhadas são possíveis, mas também significa mais parâmetros treináveis. Uma classificação mais baixa significa menos overhead computacional, mas uma adaptação potencialmente menos precisa.

Este tutorial usa uma classificação LoRA de 4. Na prática, comece com uma classificação relativamente pequena (como 4, 8, 16). Isso é eficiente computacionalmente para experimentação. Treine seu modelo com essa classificação e avalie a melhoria no desempenho da tarefa. Aumente gradualmente a classificação nas tentativas subsequentes e veja se isso melhora ainda mais o desempenho.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

A ativação da LoRA reduz significativamente o número de parâmetros treináveis (de 2,6 bilhões para 2,9 milhões).

# Limit the input sequence length to 256 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 256
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)
1000/1000 ━━━━━━━━━━━━━━━━━━━━ 923s 888ms/step - loss: 1.5586 - sparse_categorical_accuracy: 0.5251
<keras.src.callbacks.history.History at 0x799d04393c40>

Observação sobre o ajuste fino de precisão mista em GPUs NVIDIA

A precisão total é recomendada para ajustes finos. Ao ajustar as GPUs NVIDIA, é possível usar a precisão mista (keras.mixed_precision.set_global_policy('mixed_bfloat16')) para acelerar o treinamento com efeito mínimo na qualidade do treinamento. O ajuste fino de precisão mista consome mais memória, então é útil apenas em GPUs maiores.

Para inferência, a meia-precisão (keras.config.set_floatx("bfloat16")) funciona e economiza memória, mas a precisão mista não é aplicável.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

Inferência após o ajuste fino

Depois do ajuste, as respostas seguem as instruções do comando.

Europe Trip Prompt

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
When planning a trip to Europe, you should consider your budget, time and the places you want to visit. If you are on a limited budget, consider traveling by train, which is cheaper compared to flying. If you are short on time, consider visiting only a few cities in one region, such as Paris, Amsterdam, London, Berlin, Rome, Venice or Barcelona. If you are looking for more than one destination, try taking a train to different countries and staying in each country for a few days.

Agora, o modelo recomenda lugares para visitar na Europa.

Comando de fotossíntese ELI5

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
The process of photosynthesis is a chemical reaction in plants that converts the energy of sunlight into chemical energy, which the plants can then use to grow and develop. During photosynthesis, a plant will absorb carbon dioxide (CO2) from the air and water from the soil and use the energy from the sun to produce oxygen (O2) and sugars (glucose) as a by-product.

O modelo agora explica a fotossíntese de forma mais simples.

Para fins de demonstração, este tutorial ajusta o modelo em um pequeno subconjunto do conjunto de dados para apenas uma época e com um valor de classificação LoRA baixo. Para receber respostas melhores do modelo ajustado, teste:

  1. Aumentar o tamanho do conjunto de dados de ajuste fino
  2. Treinamento para mais etapas (épocas)
  3. Como definir uma classificação LoRA mais alta
  4. Modificar os valores de hiperparâmetros, como learning_rate e weight_decay.

Resumo e próximas etapas

Este tutorial abordou o ajuste fino do LoRA em um modelo Gemma usando o KerasNLP. Confira a seguir os documentos: