| | Запустить в Google Colab | | | Посмотреть исходный код на GitHub |
В этом блокноте показано, как выполнить тонкую настройку PaliGemma для задачи обработки изображений и языка с использованием JAX . Тонкая настройка — это процесс, который может улучшить производительность вашей модели в конкретных задачах или помочь модели соответствовать определенным требованиям к выходным данным, когда одних инструкций недостаточно, а у вас есть набор примеров, демонстрирующих желаемые результаты. Модели на основе Gemma, такие как PaliGemma, требуют тонкой настройки для получения ожидаемых результатов.
Что находится в этом блокноте?
В этом блокноте используется эталонная реализация модели из big_vision и показано, как:
- Установите необходимые зависимости и загрузите контрольную точку модели PaliGemma и обучающие данные.
- Загрузите модель на графические процессоры (GPU).
- Подготовьте входные данные для модели для обучения и вывода результатов.
- Доработайте модель.
- Проверьте результат.
Обучающие данные для этого ноутбука состоят из 90 пар изображений и длинных подписей к ним. Чтобы обеспечить его работу в среде выполнения T4 Colab, вам потребуется выполнить тонкую настройку только слоев внимания языковой модели, а остальные параметры зафиксировать.
Этот пример предназначен исключительно для обучения. В реальных условиях объем данных, обучаемые параметры, этапы обучения и гиперпараметры, а также полученные результаты могут существенно отличаться.
Прежде чем начать
Прежде чем приступить к работе с этим блокнотом, вам следует ознакомиться с кодом на Python, а также с тем, как обучаются большие языковые модели (LLM). Знание JAX не является обязательным, но базовые знания о JAX (или аналогичных технологиях, таких как Keras) будут полезны при чтении примеров кода.
Настраивать
В следующих разделах описаны предварительные шаги для использования модели PaliGemma в ноутбуке, включая доступ к модели, получение ключа API и настройку среды выполнения ноутбука.
Получите доступ к PaliGemma
Перед первым использованием PaliGemma необходимо запросить доступ к модели через Kaggle, выполнив следующие шаги:
- Войдите в Kaggle или создайте новую учетную запись Kaggle, если у вас ее еще нет.
- Перейдите к карточке модели PaliGemma и нажмите «Запросить доступ» .
- Заполните форму согласия и примите условия.
Настройте свой API-ключ
Для использования PaliGemma необходимо указать ваше имя пользователя Kaggle и ключ API Kaggle.
Чтобы сгенерировать ключ API Kaggle, откройте страницу настроек в Kaggle и нажмите «Создать новый токен» . Это запустит загрузку файла kaggle.json , содержащего ваши учетные данные API.
Затем в Colab выберите «Секреты » (🔑) в левой панели и добавьте свое имя пользователя Kaggle и ключ API Kaggle. Сохраните свое имя пользователя под именем KAGGLE_USERNAME , а ключ API — под именем KAGGLE_KEY .
Выберите среду выполнения
Для выполнения этого руководства вам потребуется среда выполнения Colab с достаточными ресурсами для запуска модели PaliGemma. В данном случае вы можете использовать графический процессор T4:
- В правом верхнем углу окна Colab щелкните раскрывающееся меню ▾ (Дополнительные параметры подключения) .
- Выберите «Изменить тип среды выполнения» .
- В разделе «Аппаратный ускоритель» выберите графический процессор T4 .
Установите пакеты Python.
Запустите указанную ниже ячейку, чтобы установить KaggleHub.
pip install -U -q kagglehubУстановите переменные среды
Установите переменные среды и войдите в Kaggle.
import os
# The T4 runtime is tight on memory to finetune this model. Preallocate
# all memory ahead of time to avoid out-of-memory due to fragmentation.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"
# This will prompt you to enter your Kaggle API token
import kagglehub
kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle… Kaggle credentials set. Kaggle credentials successfully validated.
Загрузите репозиторий big_vision и установите необходимые зависимости.
Загрузите репозиторий big_vision из GitHub в свой блокнот Colab и установите зависимости, связанные с big_vision , выполнив следующий код.
import os
import sys
# TPUs with
if "COLAB_TPU_ADDR" in os.environ:
raise "It seems you are using Colab with remote TPUs which is not supported."
# Fetch big_vision repository if python doesn't know about it and install
# dependencies needed for this notebook.
if not os.path.exists("big_vision_repo"):
!git clone --quiet --branch=main --depth=1 \
https://github.com/google-research/big_vision big_vision_repo
# Append big_vision code to python import path
if "big_vision_repo" not in sys.path:
sys.path.append("big_vision_repo")
# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 76.7/76.7 kB 2.8 MB/s eta 0:00:00
Импортируйте JAX и другие зависимости.
Для работы PaliGemma необходимо импортировать JAX и другие зависимости, такие как TensorFlow и NumPy.
import base64
import functools
import html
import io
import os
import warnings
import jax
import jax.numpy as jnp
import numpy as np
import ml_collections
import tensorflow as tf
import sentencepiece
from IPython.core.display import display, HTML
from PIL import Image
# Import model definition from big_vision
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns
# Import big vision utilities
import big_vision.datasets.jsonl
import big_vision.utils
import big_vision.sharding
# Don't let TF use the GPU or TPUs
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")
backend = jax.extend.backend.get_backend()
print(f"JAX version: {jax.__version__}")
print(f"JAX platform: {backend.platform}")
print(f"JAX devices: {jax.device_count()}")
JAX version: 0.7.2 JAX platform: gpu JAX devices: 1
Загрузите и настройте модель.
На этом этапе вы загрузите контрольную точку модели и настроите её, чтобы позже можно было выполнить её тонкую настройку. На этом этапе показано, как переместить параметры модели в память TPU, что полезно для тонкой настройки моделей на устройствах с ограниченными ресурсами.
Загрузите контрольную точку модели.
PaliGemma включает в себя несколько вариантов моделей. Для этого урока вы будете использовать базовую модель весов JAX/FLAX PaliGemma 3B .
Загрузите контрольную точку модели с Kaggle, выполнив следующий код. Этот процесс займет несколько минут.
import os
import kagglehub
# Use these for PaliGemma-2 3B 224px²
LLM_VARIANT = "gemma2_2b"
MODEL_PATH = "./paligemma2-3b-pt-224.b16.npz"
KAGGLE_HANDLE = "google/paligemma-2/jax/paligemma2-3b-pt-224" # Path to fetch from Kaggle.
# Use these for PaliGemma 1:
# LLM_VARIANT = "gemma_2b"
# MODEL_PATH = "./paligemma-3b-pt-224.f16.npz"
# KAGGLE_HANDLE = "google/paligemma/jax/paligemma-3b-pt-224"
if not os.path.exists(MODEL_PATH):
print("Downloading the checkpoint from Kaggle, this could take a few minutes....")
MODEL_PATH = kagglehub.model_download(KAGGLE_HANDLE, MODEL_PATH)
print(f"Model path: {MODEL_PATH}")
TOKENIZER_PATH = "./paligemma_tokenizer.model"
if not os.path.exists(TOKENIZER_PATH):
print("Downloading the model tokenizer...")
!gsutil cp gs://big_vision/paligemma_tokenizer.model {TOKENIZER_PATH}
print(f"Tokenizer path: {TOKENIZER_PATH}")
DATA_DIR="./longcap100"
if not os.path.exists(DATA_DIR):
print("Downloading the dataset...")
!gsutil -m -q cp -n -r gs://longcap100/ .
print(f"Data path: {DATA_DIR}")
Downloading the checkpoint from Kaggle, this could take a few minutes.... Downloading to /root/.cache/kagglehub/models/google/paligemma-2/jax/paligemma2-3b-pt-224/1/./paligemma2-3b-pt-224.b16.npz... 100%|██████████| 5.65G/5.65G [00:54<00:00, 112MB/s] Model path: /root/.cache/kagglehub/models/google/paligemma-2/jax/paligemma2-3b-pt-224/1/./paligemma2-3b-pt-224.b16.npz Downloading the model tokenizer... Copying gs://big_vision/paligemma_tokenizer.model... - [1 files][ 4.1 MiB/ 4.1 MiB] Operation completed over 1 objects/4.1 MiB. Tokenizer path: ./paligemma_tokenizer.model Downloading the dataset... Data path: ./longcap100
Настройте модель
Пришло время приступить к настройке модели, которую вы собираетесь использовать.
Для этого ноутбука вам потребуется разместить вашу модель на графическом процессоре T4. Ограниченные ресурсы, такие как пространство, означают, что вам необходимо тщательно продумать конфигурацию вашей модели.
Если вы будете точно настраивать каждый параметр, ваша модель не сможет работать в среде ноутбука. Поэтому в этой части ноутбука вы настроите свою модель таким образом, чтобы она могла «заморозить» некоторые параметры и точно настраивать только те параметры, которые действительно необходимы для получения точных результатов. В LLM-моделях параметры считаются «замороженными» , когда они больше не используются активно для обучения модели.
Для настройки вашей модели вам необходимо:
- Инициализируйте
model_configкакFrozenConfigDict, чтобы можно было заморозить некоторые параметры и снизить потребление памяти. - Инициализируйте экземпляр класса PaliGemma
Model, используя параметрmodel_configв качестве конфигурации. - Загрузите параметры модели в оперативную память.
- Определите функцию
decodeдля выборки выходных данных из модели.
Выполнение кода в этой ячейке занимает около минуты.
# Define model
# IMPORTANT: Gemma-2 has a "final_logits_softcap" property. Set it to 0.0
# for better transfer results.
model_config = ml_collections.FrozenConfigDict({
"llm": {"vocab_size": 257_152, "variant": LLM_VARIANT, "final_logits_softcap": 0.0},
"img": {"variant": "So400m/14", "pool_type": "none", "scan": True, "dtype_mm": "float16"}
})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)
# Load params - this can take up to 1 minute in T4 colabs.
params = paligemma.load(None, MODEL_PATH, model_config)
# Define `decode` function to sample outputs from the model.
decode_fn = predict_fns.get_all(model)['decode']
decode = functools.partial(decode_fn, devices=jax.devices(), eos_token=tokenizer.eos_id())
Переместите параметры модели в память GPU/TPU.
Теперь необходимо переместить параметры модели в память GPU/TPU. Сначала распределите параметры между доступными графическими процессорами, затем загрузите их. В данном случае параметры будут загружаться последовательно. Этот процесс занимает больше времени, чем одновременная загрузка, но требует больше оперативной памяти, чем доступно в этом ноутбуке.
Наконец, выведите все параметры, чтобы увидеть, к какому типу преобразован каждый отдельный параметр. Замороженные параметры сохраняются как float16 , а обучаемые параметры — как float32 . При просмотре списка вы увидите, что большинство параметров заморожены и имеют float16 .
# Create a pytree mask of the trainable params.
def is_trainable_param(name, param): # pylint: disable=unused-argument
if name.startswith("llm/layers/attn/"): return True
if name.startswith("llm/"): return False
if name.startswith("img/"): return False
raise ValueError(f"Unexpected param name {name}")
trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)
# If more than one device is available (e.g. multiple GPUs) the parameters can
# be sharded across them to reduce HBM usage per device.
mesh = jax.sharding.Mesh(jax.devices(), ("data"))
data_sharding = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec("data"))
params_sharding = big_vision.sharding.infer_sharding(
params, strategy=[('.*', 'fsdp(axis="data")')], mesh=mesh)
# Yes: Some donated buffers are not usable.
warnings.filterwarnings(
"ignore", message="Some donated buffers were not usable")
@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))
def maybe_cast_to_f32(params, trainable):
# Cast others to float16, since some GPUs don't support bf16.
return jax.tree.map(lambda p, m: p.astype(jnp.float32)
if m else p.astype(jnp.float16),
params, trainable)
# Loading all params in simultaneous - albeit much faster and more succinct -
# requires more RAM than the T4 colab runtimes have by default.
# Instead, do it param by param.
params, treedef = jax.tree.flatten(params)
sharding_leaves = jax.tree.leaves(params_sharding)
trainable_leaves = jax.tree.leaves(trainable_mask)
for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):
params[idx] = big_vision.utils.reshard(params[idx], sharding)
params[idx] = maybe_cast_to_f32(params[idx], trainable)
params[idx].block_until_ready()
params = jax.tree.unflatten(treedef, params)
# Print params to show what the model is made of.
def parameter_overview(params):
for path, arr in big_vision.utils.tree_flatten_with_names(params)[0]:
print(f"{path:80s} {str(arr.shape):22s} {arr.dtype}")
print(" == Model params == ")
parameter_overview(params)
== Model params == img/Transformer/encoder_norm/bias (1152,) float16 img/Transformer/encoder_norm/scale (1152,) float16 img/Transformer/encoderblock/LayerNorm_0/bias (27, 1152) float16 img/Transformer/encoderblock/LayerNorm_0/scale (27, 1152) float16 img/Transformer/encoderblock/LayerNorm_1/bias (27, 1152) float16 img/Transformer/encoderblock/LayerNorm_1/scale (27, 1152) float16 img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias (27, 4304) float16 img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel (27, 1152, 4304) float16 img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias (27, 1152) float16 img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel (27, 4304, 1152) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias (27, 16, 72) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel (27, 1152, 16, 72) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias (27, 1152) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel (27, 16, 72, 1152) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias (27, 16, 72) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel (27, 1152, 16, 72) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias (27, 16, 72) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel (27, 1152, 16, 72) float16 img/embedding/bias (1152,) float16 img/embedding/kernel (14, 14, 3, 1152) float16 img/head/bias (2304,) float16 img/head/kernel (1152, 2304) float16 img/pos_embedding (1, 256, 1152) float16 llm/embedder/input_embedding (257152, 2304) float16 llm/final_norm/scale (2304,) float16 llm/layers/attn/attn_vec_einsum/w (26, 8, 256, 2304) float32 llm/layers/attn/kv_einsum/w (26, 2, 4, 2304, 256) float32 llm/layers/attn/q_einsum/w (26, 8, 2304, 256) float32 llm/layers/mlp/gating_einsum (26, 2, 2304, 9216) float16 llm/layers/mlp/linear (26, 9216, 2304) float16 llm/layers/post_attention_norm/scale (26, 2304) float16 llm/layers/post_ffw_norm/scale (26, 2304) float16 llm/layers/pre_attention_norm/scale (26, 2304) float16 llm/layers/pre_ffw_norm/scale (26, 2304) float16
Приготовьтесь к настройке модели.
Теперь, когда ваша модель настроена, вы можете её оптимизировать. На этом этапе вы создадите входные данные для вашей модели, а также итераторы для обучения и проверки, просмотрите примеры обучения и определите циклы обучения и проверки.
Создание входных данных модели
Используемая вами контрольная точка модели уже обучена на изображениях с различными соотношениями сторон, уменьшенных до размера 224x224 пикселей, и способна обрабатывать токенизированный текст.
Приведённый ниже код определяет три функции, которые вы будете использовать на следующем шаге для создания входных данных модели:
-
preprocess_image: Нормализует данные изображения. В данном случае предварительная обработка преобразует переданное изображение в оттенки серого, удаляет альфа-канал и изменяет размер переданного изображения до размера, необходимого модели для входных изображений (224x224 пикселей). -
preprocess_tokens: Разделяет токены и добавляет флаги, указывающие, является ли токен префиксным или суффиксным. Эти флаги будут использоваться позже в коде, на этапе обучения и в цикле оценки. -
postprocess_tokens: Удаляет все токены, оставшиеся на уровне или после токена конца последовательности (EOS), и возвращает оставшиеся декодированные токены.
def preprocess_image(image, size=224):
# Model has been trained to handle images of different aspects ratios
# resized to 224x224 in the range [-1, 1]. Bilinear and antialias resize
# options are helpful to improve quality in some tasks.
image = np.asarray(image)
if image.ndim == 2: # Convert image without last channel into greyscale.
image = np.stack((image,)*3, axis=-1)
image = image[..., :3] # Remove alpha layer.
assert image.shape[-1] == 3
image = tf.constant(image)
image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)
return image.numpy() / 127.5 - 1.0 # [0, 255]->[-1,1]
def preprocess_tokens(prefix, suffix=None, seqlen=None):
# Model has been trained to handle tokenized text composed of a prefix with
# full attention and a suffix with causal attention.
separator = "\n"
tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode(separator)
mask_ar = [0] * len(tokens) # 0 to use full attention for prefix.
mask_loss = [0] * len(tokens) # 0 to not use prefix tokens in the loss.
if suffix:
suffix = tokenizer.encode(suffix, add_eos=True)
tokens += suffix
mask_ar += [1] * len(suffix) # 1 to use causal attention for suffix.
mask_loss += [1] * len(suffix) # 1 to use suffix tokens in the loss.
mask_input = [1] * len(tokens) # 1 if it's a token, 0 if padding.
if seqlen:
padding = [0] * max(0, seqlen - len(tokens))
tokens = tokens[:seqlen] + padding
mask_ar = mask_ar[:seqlen] + padding
mask_loss = mask_loss[:seqlen] + padding
mask_input = mask_input[:seqlen] + padding
return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))
def postprocess_tokens(tokens):
tokens = tokens.tolist() # np.array to list[int]
try: # Remove tokens at and after EOS if any.
eos_pos = tokens.index(tokenizer.eos_id())
tokens = tokens[:eos_pos]
except ValueError:
pass
return tokenizer.decode(tokens)
Создайте итераторы для обучения и проверки.
Создайте два итератора:
- Итератор для обучения , позволяющий процессу обучения обрабатывать данные по частям, а не все сразу.
- Это позволяет выполнить предварительную обработку данных перед использованием.
- Итератор валидации , позволяющий процессу обучения итерировать по набору данных валидации, чтобы увидеть, насколько хорошо настроенная модель соответствует предоставленным результатам.
SEQLEN = 128
train_dataset = big_vision.datasets.jsonl.DataSource(
os.path.join(DATA_DIR, "data_train90.jsonl"),
fopen_keys={"image": DATA_DIR})
val_dataset = big_vision.datasets.jsonl.DataSource(
os.path.join(DATA_DIR, "data_val10.jsonl"),
fopen_keys={"image": DATA_DIR})
def train_data_iterator():
"""Never ending iterator over training examples."""
# Shuffle examples and repeat so one can train for many epochs.
dataset = train_dataset.get_tfdata().shuffle(1_000).repeat()
for example in dataset.as_numpy_iterator():
image = Image.open(io.BytesIO(example["image"]))
image = preprocess_image(image)
prefix = "caption en" # Could also be a different prefix per example.
suffix = example["suffix"].decode().lower()
tokens, mask_ar, mask_loss, _ = preprocess_tokens(prefix, suffix, SEQLEN)
yield {
"image": np.asarray(image),
"text": np.asarray(tokens),
"mask_ar": np.asarray(mask_ar),
"mask_loss": np.asarray(mask_loss),
}
def validation_data_iterator():
"""Single iterator over validation examples."""
for example in val_dataset.get_tfdata(ordered=True).as_numpy_iterator():
image = Image.open(io.BytesIO(example["image"]))
image = preprocess_image(image)
prefix = "caption en" # Could also be a different prefix per example.
tokens, mask_ar, _, mask_input = preprocess_tokens(prefix, seqlen=SEQLEN)
yield {
"image": np.asarray(image),
"text": np.asarray(tokens),
"mask_ar": np.asarray(mask_ar),
"mask_input": np.asarray(mask_input),
}
Посмотреть примеры обучения
В этом блокноте обучающие данные содержат 90 изображений, каждое из которых сопровождается подробным описанием того, что изображено на картинке.
Приведённый ниже код выводит случайную выборку изображений с их описаниями из обучающего набора данных, чтобы вы могли увидеть, как выглядят изображения и описания, на которых обучалась ваша модель. Каждое изображение отображается в формате JPEG размером 128x128 пикселей, а описание выводится рядом с изображением справа.
def render_inline(image, resize=(128, 128)):
"""Convert image into inline html."""
image = Image.fromarray(image)
image.resize(resize)
with io.BytesIO() as buffer:
image.save(buffer, format='jpeg')
image_b64 = str(base64.b64encode(buffer.getvalue()), "utf-8")
return f"data:image/jpeg;base64,{image_b64}"
def render_example(image, caption):
image = ((image + 1)/2 * 255).astype(np.uint8) # [-1,1] -> [0, 255]
return f"""
<div style="display: inline-flex; align-items: center; justify-content: center;">
<img style="width:128px; height:128px;" src="{render_inline(image, resize=(64,64))}" />
<p style="width:256px; margin:10px; font-size:small;">{html.escape(caption)}</p>
</div>
"""
html_out = ""
for idx, example in zip(range(8), train_data_iterator()):
caption = postprocess_tokens(example["text"]) # detokenize model input.
caption = caption[len("caption en\n"):] # strip prefix
html_out += render_example(example["image"], caption)
print("Training examples")
display(HTML(html_out))
Training examples
Определите циклы обучения и оценки.
Определите цикл обучения для тренировки модели на предоставленном наборе данных и цикл оценки для анализа всех примеров в проверочном наборе данных и выполнения прогнозов.
Определение цикла обучения
Функция update_fn определяет этап обучения. На этапе обучения вычисляется функция потерь для каждого примера, и к обучаемым параметрам применяется стохастический градиентный спуск (SGD).
Напомним, что ранее в блокноте вы добавили в функцию preprocess_tokens флаги, включающие mask_loss . Здесь вы будете использовать флаг mask_loss , чтобы исключить префиксные и дополненные токены из функции потерь. Без него расчет потерь будет искажен. Вам также необходимо нормализовать каждый пример, поскольку каждый из них имеет разное количество токенов. После исключения префиксных и дополненных токенов и нормализации примеров вы можете рассчитать потери для каждого примера.
Этап обучения также включает функцию применения алгоритма стохастического градиентного спуска (SGD) для оптимизации процесса обучения.
Определение цикла оценки
Функция make_predictions — это цикл оценки. Цикл оценки довольно прост, за исключением одного существенного изменения. Как вы помните из начала блокнота, в вашем обучающем наборе данных всего 90 примеров. Это очень небольшое количество обучающих примеров, и в итоге вашей модели не хватает примеров для размера пакета при запуске обучения. Это означает, что в цикле оценки вам необходимо дополнить пакет повторяющимися примерами.
Чтобы гарантировать, что ваш цикл оценки будет учитывать только фактические примеры, а не примеры с заполненными данными, необходимо применить маску к примерам с заполненными данными, которая исключит их из выходных данных.
# The main update_fn using a simple stochastic gradient descent (SGD).
@functools.partial(jax.jit, donate_argnums=(0,))
def update_fn(params, batch, learning_rate):
imgs, txts, mask_ar = batch["image"], batch["text"], batch["mask_ar"]
def loss_fn(params):
text_logits, _ = model.apply({"params": params}, imgs, txts[:, :-1], mask_ar[:, :-1], train=True)
logp = jax.nn.log_softmax(text_logits, axis=-1)
# The model takes as input txts[:, :-1] but the loss is defined as predicting
# next tokens txts[:, 1:]. Additionally, mask_loss[:, 1:] indicates which tokens
# are part of the loss (e.g. prefix and padded tokens are not included).
mask_loss = batch["mask_loss"][:, 1:]
targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])
# Compute the loss per example. i.e. the mean of per token pplx.
# Since each example has a different number of tokens, normalize it.
token_pplx = jnp.sum(logp * targets, axis=-1) # sum across vocab_size.
example_loss = -jnp.sum(token_pplx * mask_loss, axis=-1) # sum across seq_len.
example_loss /= jnp.clip(jnp.sum(mask_loss, -1), 1) # weight by num of tokens.
# batch_loss: mean of per example loss.
return jnp.mean(example_loss)
loss, grads = jax.value_and_grad(loss_fn)(params)
# Apply gradients to trainable params using SGD.
def apply_grad(param, gradient, trainable):
if not trainable: return param
return param - learning_rate * gradient
params = jax.tree_util.tree_map(apply_grad, params, grads, trainable_mask)
return params, loss
# Evaluation/inference loop.
def make_predictions(data_iterator, *, num_examples=None,
batch_size=4, seqlen=SEQLEN, sampler="greedy"):
outputs = []
while True:
# Construct a list of examples in the batch.
examples = []
try:
for _ in range(batch_size):
examples.append(next(data_iterator))
examples[-1]["_mask"] = np.array(True) # Indicates true example.
except StopIteration:
if len(examples) == 0:
return outputs
# Not enough examples to complete a batch. Pad by repeating last example.
while len(examples) % batch_size:
examples.append(dict(examples[-1]))
examples[-1]["_mask"] = np.array(False) # Indicates padding example.
# Convert list of examples into a dict of np.arrays and load onto devices.
batch = jax.tree.map(lambda *x: np.stack(x), *examples)
batch = big_vision.utils.reshard(batch, data_sharding)
# Make model predictions
tokens = decode({"params": params}, batch=batch,
max_decode_len=seqlen, sampler=sampler)
# Fetch model predictions to device and detokenize.
tokens, mask = jax.device_get((tokens, batch["_mask"]))
tokens = tokens[mask] # remove padding examples.
responses = [postprocess_tokens(t) for t in tokens]
# Append to html output.
for example, response in zip(examples, responses):
outputs.append((example["image"], response))
if num_examples and len(outputs) >= num_examples:
return outputs
Настройте модель
Теперь, когда вы все настроили и ознакомились с обучающими данными, пришло время наконец настроить модель. Приведенный ниже код запускает цикл обучения модели на 64 шага и выводит скорость обучения ( lr в выходных данных) и коэффициент потерь для каждого шага.
Каждые 16 шагов модель выводит свои предсказания на этом этапе обучения. Этот код выводит предсказания для одного и того же набора изображений, чтобы вы могли видеть, как способность модели предсказывать описания улучшается со временем.
На ранних этапах обучения могут возникать проблемы с описаниями, например, повторяющиеся предложения, когда модель застревает в цикле прогнозирования, или незаконченные предложения. По мере обучения точность прогнозов модели неуклонно повышается. К 64-му шагу прогнозы модели должны максимально точно соответствовать описаниям, предоставленным обучающими данными.
На обработку T4 TPU этот процесс занимает около 15 минут.
# Run a short training loop with cosine learning rate schedule.
#
# Note: the first step can be quite slow on some machines (up to several minutes)
# due to XLA compilation of the jax.jit'd function.
#
%%time
BATCH_SIZE = 8
TRAIN_EXAMPLES = 512
LEARNING_RATE = 0.03
TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE
EVAL_STEPS = TRAIN_STEPS // 4
train_data_it = train_data_iterator()
sched_fn = big_vision.utils.create_learning_rate_schedule(
total_steps=TRAIN_STEPS+1, base=LEARNING_RATE,
decay_type="cosine", warmup_percent=0.10)
for step in range(1, TRAIN_STEPS+1):
# Make list of N training examples.
examples = [next(train_data_it) for _ in range(BATCH_SIZE)]
# Convert list of examples into a dict of np.arrays and load onto devices.
batch = jax.tree.map(lambda *x: np.stack(x), *examples)
batch = big_vision.utils.reshard(batch, data_sharding)
# Training step and report training loss
learning_rate = sched_fn(step)
params, loss = update_fn(params, batch, learning_rate)
loss = jax.device_get(loss)
print(f"step: {step:2d}/{TRAIN_STEPS:2d} lr: {learning_rate:.5f} loss: {loss:.4f}")
if (step % EVAL_STEPS) == 0:
print(f"Model predictions at step {step}")
html_out = ""
for image, caption in make_predictions(
validation_data_iterator(), num_examples=4, batch_size=4):
html_out += render_example(image, caption)
display(HTML(html_out))
step: 1/64 lr: 0.00500 loss: 3.6567 step: 2/64 lr: 0.01000 loss: 1.9762 step: 3/64 lr: 0.01500 loss: 1.6299 step: 4/64 lr: 0.02000 loss: 1.5651 step: 5/64 lr: 0.02500 loss: 1.9813 step: 6/64 lr: 0.03000 loss: 1.9996 step: 7/64 lr: 0.02998 loss: 1.8595 step: 8/64 lr: 0.02992 loss: 1.6479 step: 9/64 lr: 0.02981 loss: 1.3693 step: 10/64 lr: 0.02966 loss: 1.3423 step: 11/64 lr: 0.02947 loss: 1.2122 step: 12/64 lr: 0.02924 loss: 1.0602 step: 13/64 lr: 0.02897 loss: 1.1314 step: 14/64 lr: 0.02866 loss: 1.2612 step: 15/64 lr: 0.02831 loss: 1.0132 step: 16/64 lr: 0.02792 loss: 1.2126 Model predictions at step 16
step: 17/64 lr: 0.02750 loss: 1.0986 step: 18/64 lr: 0.02704 loss: 0.9461 step: 19/64 lr: 0.02655 loss: 1.2098 step: 20/64 lr: 0.02602 loss: 1.0513 step: 21/64 lr: 0.02546 loss: 1.0979 step: 22/64 lr: 0.02488 loss: 0.9739 step: 23/64 lr: 0.02426 loss: 0.9589 step: 24/64 lr: 0.02362 loss: 0.7053 step: 25/64 lr: 0.02296 loss: 0.7347 step: 26/64 lr: 0.02227 loss: 0.6990 step: 27/64 lr: 0.02156 loss: 0.6736 step: 28/64 lr: 0.02083 loss: 0.6642 step: 29/64 lr: 0.02009 loss: 0.6908 step: 30/64 lr: 0.01933 loss: 0.7257 step: 31/64 lr: 0.01856 loss: 0.6902 step: 32/64 lr: 0.01778 loss: 0.7054 Model predictions at step 32
step: 33/64 lr: 0.01699 loss: 0.7709 step: 34/64 lr: 0.01620 loss: 0.6653 step: 35/64 lr: 0.01540 loss: 0.3811 step: 36/64 lr: 0.01460 loss: 0.3104 step: 37/64 lr: 0.01380 loss: 0.4042 step: 38/64 lr: 0.01301 loss: 0.3904 step: 39/64 lr: 0.01222 loss: 0.3339 step: 40/64 lr: 0.01144 loss: 0.4156 step: 41/64 lr: 0.01067 loss: 0.4085 step: 42/64 lr: 0.00991 loss: 0.3083 step: 43/64 lr: 0.00917 loss: 0.3757 step: 44/64 lr: 0.00844 loss: 0.3813 step: 45/64 lr: 0.00773 loss: 0.3381 step: 46/64 lr: 0.00704 loss: 0.2057 step: 47/64 lr: 0.00638 loss: 0.1287 step: 48/64 lr: 0.00574 loss: 0.1711 Model predictions at step 48
step: 49/64 lr: 0.00512 loss: 0.1183 step: 50/64 lr: 0.00454 loss: 0.1154 step: 51/64 lr: 0.00398 loss: 0.1967 step: 52/64 lr: 0.00345 loss: 0.1497 step: 53/64 lr: 0.00296 loss: 0.1688 step: 54/64 lr: 0.00250 loss: 0.1878 step: 55/64 lr: 0.00208 loss: 0.1865 step: 56/64 lr: 0.00169 loss: 0.1655 step: 57/64 lr: 0.00134 loss: 0.0911 step: 58/64 lr: 0.00103 loss: 0.1836 step: 59/64 lr: 0.00076 loss: 0.1242 step: 60/64 lr: 0.00053 loss: 0.0814 step: 61/64 lr: 0.00034 loss: 0.0866 step: 62/64 lr: 0.00019 loss: 0.1295 step: 63/64 lr: 0.00008 loss: 0.1053 step: 64/64 lr: 0.00002 loss: 0.0730 Model predictions at step 64
CPU times: user 2min 18s, sys: 8.98 s, total: 2min 27s Wall time: 15min 45s
Выход
В этом блокноте для проверки используются всего 10 изображений. В обычном коде у вас, вероятно, было бы гораздо больше точек данных для проверки, но для этого блокнота выполните следующий код, чтобы сгенерировать описания для всех 10 изображений. После настройки модели эти описания должны быть очень похожи по форме и содержанию на описания, включенные в обучающие данные, которые вы рассматривали ранее в этом блокноте.
Выполните приведенный ниже код, чтобы сгенерировать описания для набора данных для проверки.
# The validation data consists of 10 images in a different domain than training
# data.
%%time
print("Model predictions")
html_out = ""
for image, caption in make_predictions(validation_data_iterator(), batch_size=4):
html_out += render_example(image, caption)
display(HTML(html_out))
Model predictions
CPU times: user 1.87 s, sys: 283 ms, total: 2.15 s Wall time: 39.3 s
Запустить в Google Colab
Посмотреть исходный код на GitHub