Доработайте PaliGemma с помощью JAX и Flax.

Посмотреть на ai.google.dev Запустить в Google Colab Запуск в Kaggle Открыть в Vertex AI Посмотреть исходный код на GitHub

В этом блокноте показано, как выполнить тонкую настройку PaliGemma для задачи обработки изображений и языка с использованием JAX . Тонкая настройка — это процесс, который может улучшить производительность вашей модели в конкретных задачах или помочь модели соответствовать определенным требованиям к выходным данным, когда одних инструкций недостаточно, а у вас есть набор примеров, демонстрирующих желаемые результаты. Модели на основе Gemma, такие как PaliGemma, требуют тонкой настройки для получения ожидаемых результатов.

Что находится в этом блокноте?

В этом блокноте используется эталонная реализация модели из big_vision и показано, как:

  • Установите необходимые зависимости и загрузите контрольную точку модели PaliGemma и обучающие данные.
  • Загрузите модель на графические процессоры (GPU).
  • Подготовьте входные данные для модели для обучения и вывода результатов.
  • Доработайте модель.
  • Проверьте результат.

Обучающие данные для этого ноутбука состоят из 90 пар изображений и длинных подписей к ним. Чтобы обеспечить его работу в среде выполнения T4 Colab, вам потребуется выполнить тонкую настройку только слоев внимания языковой модели, а остальные параметры зафиксировать.

Этот пример предназначен исключительно для обучения. В реальных условиях объем данных, обучаемые параметры, этапы обучения и гиперпараметры, а также полученные результаты могут существенно отличаться.

Прежде чем начать

Прежде чем приступить к работе с этим блокнотом, вам следует ознакомиться с кодом на Python, а также с тем, как обучаются большие языковые модели (LLM). Знание JAX не является обязательным, но базовые знания о JAX (или аналогичных технологиях, таких как Keras) будут полезны при чтении примеров кода.

Настраивать

В следующих разделах описаны предварительные шаги для использования модели PaliGemma в ноутбуке, включая доступ к модели, получение ключа API и настройку среды выполнения ноутбука.

Получите доступ к PaliGemma

Перед первым использованием PaliGemma необходимо запросить доступ к модели через Kaggle, выполнив следующие шаги:

  1. Войдите в Kaggle или создайте новую учетную запись Kaggle, если у вас ее еще нет.
  2. Перейдите к карточке модели PaliGemma и нажмите «Запросить доступ» .
  3. Заполните форму согласия и примите условия.

Настройте свой API-ключ

Для использования PaliGemma необходимо указать ваше имя пользователя Kaggle и ключ API Kaggle.

Чтобы сгенерировать ключ API Kaggle, откройте страницу настроек в Kaggle и нажмите «Создать новый токен» . Это запустит загрузку файла kaggle.json , содержащего ваши учетные данные API.

Затем в Colab выберите «Секреты » (🔑) в левой панели и добавьте свое имя пользователя Kaggle и ключ API Kaggle. Сохраните свое имя пользователя под именем KAGGLE_USERNAME , а ключ API — под именем KAGGLE_KEY .

Выберите среду выполнения

Для выполнения этого руководства вам потребуется среда выполнения Colab с достаточными ресурсами для запуска модели PaliGemma. В данном случае вы можете использовать графический процессор T4:

  1. В правом верхнем углу окна Colab щелкните раскрывающееся меню ▾ (Дополнительные параметры подключения) .
  2. Выберите «Изменить тип среды выполнения» .
  3. В разделе «Аппаратный ускоритель» выберите графический процессор T4 .

Установите пакеты Python.

Запустите указанную ниже ячейку, чтобы установить KaggleHub.

pip install -U -q kagglehub

Установите переменные среды

Установите переменные среды и войдите в Kaggle.

import os
# The T4 runtime is tight on memory to finetune this model. Preallocate
# all memory ahead of time to avoid out-of-memory due to fragmentation.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"

# This will prompt you to enter your Kaggle API token
import kagglehub
kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…
Kaggle credentials set.
Kaggle credentials successfully validated.

Загрузите репозиторий big_vision из GitHub в свой блокнот Colab и установите зависимости, связанные с big_vision , выполнив следующий код.

import os
import sys

# TPUs with
if "COLAB_TPU_ADDR" in os.environ:
  raise "It seems you are using Colab with remote TPUs which is not supported."

# Fetch big_vision repository if python doesn't know about it and install
# dependencies needed for this notebook.
if not os.path.exists("big_vision_repo"):
  !git clone --quiet --branch=main --depth=1 \
     https://github.com/google-research/big_vision big_vision_repo

# Append big_vision code to python import path
if "big_vision_repo" not in sys.path:
  sys.path.append("big_vision_repo")

# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 76.7/76.7 kB 2.8 MB/s eta 0:00:00

Импортируйте JAX и другие зависимости.

Для работы PaliGemma необходимо импортировать JAX и другие зависимости, такие как TensorFlow и NumPy.

import base64
import functools
import html
import io
import os
import warnings

import jax
import jax.numpy as jnp
import numpy as np
import ml_collections

import tensorflow as tf
import sentencepiece

from IPython.core.display import display, HTML
from PIL import Image

# Import model definition from big_vision
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns

# Import big vision utilities
import big_vision.datasets.jsonl
import big_vision.utils
import big_vision.sharding

# Don't let TF use the GPU or TPUs
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")

backend = jax.extend.backend.get_backend()
print(f"JAX version:  {jax.__version__}")
print(f"JAX platform: {backend.platform}")
print(f"JAX devices:  {jax.device_count()}")
JAX version:  0.7.2
JAX platform: gpu
JAX devices:  1

Загрузите и настройте модель.

На этом этапе вы загрузите контрольную точку модели и настроите её, чтобы позже можно было выполнить её тонкую настройку. На этом этапе показано, как переместить параметры модели в память TPU, что полезно для тонкой настройки моделей на устройствах с ограниченными ресурсами.

Загрузите контрольную точку модели.

PaliGemma включает в себя несколько вариантов моделей. Для этого урока вы будете использовать базовую модель весов JAX/FLAX PaliGemma 3B .

Загрузите контрольную точку модели с Kaggle, выполнив следующий код. Этот процесс займет несколько минут.

import os
import kagglehub

# Use these for PaliGemma-2 3B 224px²
LLM_VARIANT = "gemma2_2b"
MODEL_PATH = "./paligemma2-3b-pt-224.b16.npz"
KAGGLE_HANDLE = "google/paligemma-2/jax/paligemma2-3b-pt-224"  # Path to fetch from Kaggle.

# Use these for PaliGemma 1:
# LLM_VARIANT = "gemma_2b"
# MODEL_PATH = "./paligemma-3b-pt-224.f16.npz"
# KAGGLE_HANDLE = "google/paligemma/jax/paligemma-3b-pt-224"

if not os.path.exists(MODEL_PATH):
  print("Downloading the checkpoint from Kaggle, this could take a few minutes....")
  MODEL_PATH = kagglehub.model_download(KAGGLE_HANDLE, MODEL_PATH)
  print(f"Model path: {MODEL_PATH}")

TOKENIZER_PATH = "./paligemma_tokenizer.model"
if not os.path.exists(TOKENIZER_PATH):
  print("Downloading the model tokenizer...")
  !gsutil cp gs://big_vision/paligemma_tokenizer.model {TOKENIZER_PATH}
  print(f"Tokenizer path: {TOKENIZER_PATH}")

DATA_DIR="./longcap100"
if not os.path.exists(DATA_DIR):
  print("Downloading the dataset...")
  !gsutil -m -q cp -n -r gs://longcap100/ .
  print(f"Data path: {DATA_DIR}")
Downloading the checkpoint from Kaggle, this could take a few minutes....
Downloading to /root/.cache/kagglehub/models/google/paligemma-2/jax/paligemma2-3b-pt-224/1/./paligemma2-3b-pt-224.b16.npz...
100%|██████████| 5.65G/5.65G [00:54<00:00, 112MB/s]
Model path: /root/.cache/kagglehub/models/google/paligemma-2/jax/paligemma2-3b-pt-224/1/./paligemma2-3b-pt-224.b16.npz
Downloading the model tokenizer...
Copying gs://big_vision/paligemma_tokenizer.model...

- [1 files][  4.1 MiB/  4.1 MiB]                                                
Operation completed over 1 objects/4.1 MiB.                                      
Tokenizer path: ./paligemma_tokenizer.model
Downloading the dataset...
Data path: ./longcap100

Настройте модель

Пришло время приступить к настройке модели, которую вы собираетесь использовать.

Для этого ноутбука вам потребуется разместить вашу модель на графическом процессоре T4. Ограниченные ресурсы, такие как пространство, означают, что вам необходимо тщательно продумать конфигурацию вашей модели.

Если вы будете точно настраивать каждый параметр, ваша модель не сможет работать в среде ноутбука. Поэтому в этой части ноутбука вы настроите свою модель таким образом, чтобы она могла «заморозить» некоторые параметры и точно настраивать только те параметры, которые действительно необходимы для получения точных результатов. В LLM-моделях параметры считаются «замороженными» , когда они больше не используются активно для обучения модели.

Для настройки вашей модели вам необходимо:

  • Инициализируйте model_config как FrozenConfigDict , чтобы можно было заморозить некоторые параметры и снизить потребление памяти.
  • Инициализируйте экземпляр класса PaliGemma Model , используя параметр model_config в качестве конфигурации.
  • Загрузите параметры модели в оперативную память.
  • Определите функцию decode для выборки выходных данных из модели.

Выполнение кода в этой ячейке занимает около минуты.

# Define model

# IMPORTANT: Gemma-2 has a "final_logits_softcap" property. Set it to 0.0
# for better transfer results.
model_config = ml_collections.FrozenConfigDict({
    "llm": {"vocab_size": 257_152, "variant": LLM_VARIANT, "final_logits_softcap": 0.0},
    "img": {"variant": "So400m/14", "pool_type": "none", "scan": True, "dtype_mm": "float16"}
})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)

# Load params - this can take up to 1 minute in T4 colabs.
params = paligemma.load(None, MODEL_PATH, model_config)

# Define `decode` function to sample outputs from the model.
decode_fn = predict_fns.get_all(model)['decode']
decode = functools.partial(decode_fn, devices=jax.devices(), eos_token=tokenizer.eos_id())

Переместите параметры модели в память GPU/TPU.

Теперь необходимо переместить параметры модели в память GPU/TPU. Сначала распределите параметры между доступными графическими процессорами, затем загрузите их. В данном случае параметры будут загружаться последовательно. Этот процесс занимает больше времени, чем одновременная загрузка, но требует больше оперативной памяти, чем доступно в этом ноутбуке.

Наконец, выведите все параметры, чтобы увидеть, к какому типу преобразован каждый отдельный параметр. Замороженные параметры сохраняются как float16 , а обучаемые параметры — как float32 . При просмотре списка вы увидите, что большинство параметров заморожены и имеют float16 .

# Create a pytree mask of the trainable params.
def is_trainable_param(name, param):  # pylint: disable=unused-argument
  if name.startswith("llm/layers/attn/"):  return True
  if name.startswith("llm/"):              return False
  if name.startswith("img/"):              return False
  raise ValueError(f"Unexpected param name {name}")
trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)

# If more than one device is available (e.g. multiple GPUs) the parameters can
# be sharded across them to reduce HBM usage per device.
mesh = jax.sharding.Mesh(jax.devices(), ("data"))

data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("data"))

params_sharding = big_vision.sharding.infer_sharding(
    params, strategy=[('.*', 'fsdp(axis="data")')], mesh=mesh)

# Yes: Some donated buffers are not usable.
warnings.filterwarnings(
    "ignore", message="Some donated buffers were not usable")

@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))
def maybe_cast_to_f32(params, trainable):
  # Cast others to float16, since some GPUs don't support bf16.
  return jax.tree.map(lambda p, m: p.astype(jnp.float32)
                      if m else p.astype(jnp.float16),
                      params, trainable)

# Loading all params in simultaneous - albeit much faster and more succinct -
# requires more RAM than the T4 colab runtimes have by default.
# Instead, do it param by param.
params, treedef = jax.tree.flatten(params)
sharding_leaves = jax.tree.leaves(params_sharding)
trainable_leaves = jax.tree.leaves(trainable_mask)
for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):
  params[idx] = big_vision.utils.reshard(params[idx], sharding)
  params[idx] = maybe_cast_to_f32(params[idx], trainable)
  params[idx].block_until_ready()
params = jax.tree.unflatten(treedef, params)

# Print params to show what the model is made of.
def parameter_overview(params):
  for path, arr in big_vision.utils.tree_flatten_with_names(params)[0]:
    print(f"{path:80s} {str(arr.shape):22s} {arr.dtype}")

print(" == Model params == ")
parameter_overview(params)
== Model params == 
img/Transformer/encoder_norm/bias                                                (1152,)                float16
img/Transformer/encoder_norm/scale                                               (1152,)                float16
img/Transformer/encoderblock/LayerNorm_0/bias                                    (27, 1152)             float16
img/Transformer/encoderblock/LayerNorm_0/scale                                   (27, 1152)             float16
img/Transformer/encoderblock/LayerNorm_1/bias                                    (27, 1152)             float16
img/Transformer/encoderblock/LayerNorm_1/scale                                   (27, 1152)             float16
img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias                             (27, 4304)             float16
img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel                           (27, 1152, 4304)       float16
img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias                             (27, 1152)             float16
img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel                           (27, 4304, 1152)       float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias             (27, 16, 72)           float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel           (27, 1152, 16, 72)     float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias             (27, 1152)             float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel           (27, 16, 72, 1152)     float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias           (27, 16, 72)           float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel         (27, 1152, 16, 72)     float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias           (27, 16, 72)           float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel         (27, 1152, 16, 72)     float16
img/embedding/bias                                                               (1152,)                float16
img/embedding/kernel                                                             (14, 14, 3, 1152)      float16
img/head/bias                                                                    (2304,)                float16
img/head/kernel                                                                  (1152, 2304)           float16
img/pos_embedding                                                                (1, 256, 1152)         float16
llm/embedder/input_embedding                                                     (257152, 2304)         float16
llm/final_norm/scale                                                             (2304,)                float16
llm/layers/attn/attn_vec_einsum/w                                                (26, 8, 256, 2304)     float32
llm/layers/attn/kv_einsum/w                                                      (26, 2, 4, 2304, 256)  float32
llm/layers/attn/q_einsum/w                                                       (26, 8, 2304, 256)     float32
llm/layers/mlp/gating_einsum                                                     (26, 2, 2304, 9216)    float16
llm/layers/mlp/linear                                                            (26, 9216, 2304)       float16
llm/layers/post_attention_norm/scale                                             (26, 2304)             float16
llm/layers/post_ffw_norm/scale                                                   (26, 2304)             float16
llm/layers/pre_attention_norm/scale                                              (26, 2304)             float16
llm/layers/pre_ffw_norm/scale                                                    (26, 2304)             float16

Приготовьтесь к настройке модели.

Теперь, когда ваша модель настроена, вы можете её оптимизировать. На этом этапе вы создадите входные данные для вашей модели, а также итераторы для обучения и проверки, просмотрите примеры обучения и определите циклы обучения и проверки.

Создание входных данных модели

Используемая вами контрольная точка модели уже обучена на изображениях с различными соотношениями сторон, уменьшенных до размера 224x224 пикселей, и способна обрабатывать токенизированный текст.

Приведённый ниже код определяет три функции, которые вы будете использовать на следующем шаге для создания входных данных модели:

  • preprocess_image : Нормализует данные изображения. В данном случае предварительная обработка преобразует переданное изображение в оттенки серого, удаляет альфа-канал и изменяет размер переданного изображения до размера, необходимого модели для входных изображений (224x224 пикселей).
  • preprocess_tokens : Разделяет токены и добавляет флаги, указывающие, является ли токен префиксным или суффиксным. Эти флаги будут использоваться позже в коде, на этапе обучения и в цикле оценки.
  • postprocess_tokens : Удаляет все токены, оставшиеся на уровне или после токена конца последовательности (EOS), и возвращает оставшиеся декодированные токены.
def preprocess_image(image, size=224):
  # Model has been trained to handle images of different aspects ratios
  # resized to 224x224 in the range [-1, 1]. Bilinear and antialias resize
  # options are helpful to improve quality in some tasks.
  image = np.asarray(image)
  if image.ndim == 2:  # Convert image without last channel into greyscale.
    image = np.stack((image,)*3, axis=-1)
  image = image[..., :3]  # Remove alpha layer.
  assert image.shape[-1] == 3

  image = tf.constant(image)
  image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)
  return image.numpy() / 127.5 - 1.0  # [0, 255]->[-1,1]

def preprocess_tokens(prefix, suffix=None, seqlen=None):
  # Model has been trained to handle tokenized text composed of a prefix with
  # full attention and a suffix with causal attention.
  separator = "\n"
  tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode(separator)
  mask_ar = [0] * len(tokens)    # 0 to use full attention for prefix.
  mask_loss = [0] * len(tokens)  # 0 to not use prefix tokens in the loss.

  if suffix:
    suffix = tokenizer.encode(suffix, add_eos=True)
    tokens += suffix
    mask_ar += [1] * len(suffix)    # 1 to use causal attention for suffix.
    mask_loss += [1] * len(suffix)  # 1 to use suffix tokens in the loss.

  mask_input = [1] * len(tokens)    # 1 if it's a token, 0 if padding.
  if seqlen:
    padding = [0] * max(0, seqlen - len(tokens))
    tokens = tokens[:seqlen] + padding
    mask_ar = mask_ar[:seqlen] + padding
    mask_loss = mask_loss[:seqlen] + padding
    mask_input = mask_input[:seqlen] + padding

  return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))

def postprocess_tokens(tokens):
  tokens = tokens.tolist()  # np.array to list[int]
  try:  # Remove tokens at and after EOS if any.
    eos_pos = tokens.index(tokenizer.eos_id())
    tokens = tokens[:eos_pos]
  except ValueError:
    pass
  return tokenizer.decode(tokens)

Создайте итераторы для обучения и проверки.

Создайте два итератора:

  • Итератор для обучения , позволяющий процессу обучения обрабатывать данные по частям, а не все сразу.
    • Это позволяет выполнить предварительную обработку данных перед использованием.
  • Итератор валидации , позволяющий процессу обучения итерировать по набору данных валидации, чтобы увидеть, насколько хорошо настроенная модель соответствует предоставленным результатам.
SEQLEN = 128

train_dataset = big_vision.datasets.jsonl.DataSource(
    os.path.join(DATA_DIR, "data_train90.jsonl"),
    fopen_keys={"image": DATA_DIR})

val_dataset = big_vision.datasets.jsonl.DataSource(
    os.path.join(DATA_DIR, "data_val10.jsonl"),
    fopen_keys={"image": DATA_DIR})


def train_data_iterator():
  """Never ending iterator over training examples."""
  # Shuffle examples and repeat so one can train for many epochs.
  dataset = train_dataset.get_tfdata().shuffle(1_000).repeat()
  for example in dataset.as_numpy_iterator():
    image = Image.open(io.BytesIO(example["image"]))
    image = preprocess_image(image)

    prefix = "caption en"  # Could also be a different prefix per example.
    suffix = example["suffix"].decode().lower()
    tokens, mask_ar, mask_loss, _ = preprocess_tokens(prefix, suffix, SEQLEN)

    yield {
        "image": np.asarray(image),
        "text": np.asarray(tokens),
        "mask_ar": np.asarray(mask_ar),
        "mask_loss": np.asarray(mask_loss),
    }


def validation_data_iterator():
  """Single iterator over validation examples."""
  for example in val_dataset.get_tfdata(ordered=True).as_numpy_iterator():
    image = Image.open(io.BytesIO(example["image"]))
    image = preprocess_image(image)

    prefix = "caption en"  # Could also be a different prefix per example.
    tokens, mask_ar, _, mask_input = preprocess_tokens(prefix, seqlen=SEQLEN)

    yield {
        "image": np.asarray(image),
        "text": np.asarray(tokens),
        "mask_ar": np.asarray(mask_ar),
        "mask_input": np.asarray(mask_input),
    }

Посмотреть примеры обучения

В этом блокноте обучающие данные содержат 90 изображений, каждое из которых сопровождается подробным описанием того, что изображено на картинке.

Приведённый ниже код выводит случайную выборку изображений с их описаниями из обучающего набора данных, чтобы вы могли увидеть, как выглядят изображения и описания, на которых обучалась ваша модель. Каждое изображение отображается в формате JPEG размером 128x128 пикселей, а описание выводится рядом с изображением справа.

def render_inline(image, resize=(128, 128)):
  """Convert image into inline html."""
  image = Image.fromarray(image)
  image.resize(resize)
  with io.BytesIO() as buffer:
    image.save(buffer, format='jpeg')
    image_b64 = str(base64.b64encode(buffer.getvalue()), "utf-8")
    return f"data:image/jpeg;base64,{image_b64}"

def render_example(image, caption):
  image = ((image + 1)/2 * 255).astype(np.uint8)  # [-1,1] -> [0, 255]
  return f"""
    <div style="display: inline-flex; align-items: center; justify-content: center;">
        <img style="width:128px; height:128px;" src="{render_inline(image, resize=(64,64))}" />
        <p style="width:256px; margin:10px; font-size:small;">{html.escape(caption)}</p>
    </div>
    """

html_out = ""
for idx, example in zip(range(8), train_data_iterator()):
  caption = postprocess_tokens(example["text"])  # detokenize model input.
  caption = caption[len("caption en\n"):]        # strip prefix
  html_out += render_example(example["image"], caption)

print("Training examples")
display(HTML(html_out))
Training examples

Определите циклы обучения и оценки.

Определите цикл обучения для тренировки модели на предоставленном наборе данных и цикл оценки для анализа всех примеров в проверочном наборе данных и выполнения прогнозов.

Определение цикла обучения

Функция update_fn определяет этап обучения. На этапе обучения вычисляется функция потерь для каждого примера, и к обучаемым параметрам применяется стохастический градиентный спуск (SGD).

Напомним, что ранее в блокноте вы добавили в функцию preprocess_tokens флаги, включающие mask_loss . Здесь вы будете использовать флаг mask_loss , чтобы исключить префиксные и дополненные токены из функции потерь. Без него расчет потерь будет искажен. Вам также необходимо нормализовать каждый пример, поскольку каждый из них имеет разное количество токенов. После исключения префиксных и дополненных токенов и нормализации примеров вы можете рассчитать потери для каждого примера.

Этап обучения также включает функцию применения алгоритма стохастического градиентного спуска (SGD) для оптимизации процесса обучения.

Определение цикла оценки

Функция make_predictions — это цикл оценки. Цикл оценки довольно прост, за исключением одного существенного изменения. Как вы помните из начала блокнота, в вашем обучающем наборе данных всего 90 примеров. Это очень небольшое количество обучающих примеров, и в итоге вашей модели не хватает примеров для размера пакета при запуске обучения. Это означает, что в цикле оценки вам необходимо дополнить пакет повторяющимися примерами.

Чтобы гарантировать, что ваш цикл оценки будет учитывать только фактические примеры, а не примеры с заполненными данными, необходимо применить маску к примерам с заполненными данными, которая исключит их из выходных данных.

# The main update_fn using a simple stochastic gradient descent (SGD).
@functools.partial(jax.jit, donate_argnums=(0,))
def update_fn(params, batch, learning_rate):
  imgs, txts, mask_ar = batch["image"], batch["text"], batch["mask_ar"]

  def loss_fn(params):
    text_logits, _ = model.apply({"params": params}, imgs, txts[:, :-1], mask_ar[:, :-1], train=True)
    logp = jax.nn.log_softmax(text_logits, axis=-1)

    # The model takes as input txts[:, :-1] but the loss is defined as predicting
    # next tokens txts[:, 1:]. Additionally, mask_loss[:, 1:] indicates which tokens
    # are part of the loss (e.g. prefix and padded tokens are not included).
    mask_loss = batch["mask_loss"][:, 1:]
    targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])

    # Compute the loss per example. i.e. the mean of per token pplx.
    # Since each example has a different number of tokens, normalize it.
    token_pplx = jnp.sum(logp * targets, axis=-1)  # sum across vocab_size.
    example_loss = -jnp.sum(token_pplx * mask_loss, axis=-1)  # sum across seq_len.
    example_loss /= jnp.clip(jnp.sum(mask_loss, -1), 1)  # weight by num of tokens.

    # batch_loss: mean of per example loss.
    return jnp.mean(example_loss)

  loss, grads = jax.value_and_grad(loss_fn)(params)

  # Apply gradients to trainable params using SGD.
  def apply_grad(param, gradient, trainable):
    if not trainable: return param
    return param - learning_rate * gradient

  params = jax.tree_util.tree_map(apply_grad, params, grads, trainable_mask)

  return params, loss

# Evaluation/inference loop.
def make_predictions(data_iterator, *, num_examples=None,
                     batch_size=4, seqlen=SEQLEN, sampler="greedy"):
  outputs = []
  while True:
    # Construct a list of examples in the batch.
    examples = []
    try:
      for _ in range(batch_size):
        examples.append(next(data_iterator))
        examples[-1]["_mask"] = np.array(True)  # Indicates true example.
    except StopIteration:
      if len(examples) == 0:
        return outputs

    # Not enough examples to complete a batch. Pad by repeating last example.
    while len(examples) % batch_size:
      examples.append(dict(examples[-1]))
      examples[-1]["_mask"] = np.array(False)  # Indicates padding example.

    # Convert list of examples into a dict of np.arrays and load onto devices.
    batch = jax.tree.map(lambda *x: np.stack(x), *examples)
    batch = big_vision.utils.reshard(batch, data_sharding)

    # Make model predictions
    tokens = decode({"params": params}, batch=batch,
                    max_decode_len=seqlen, sampler=sampler)

    # Fetch model predictions to device and detokenize.
    tokens, mask = jax.device_get((tokens, batch["_mask"]))
    tokens = tokens[mask]  # remove padding examples.
    responses = [postprocess_tokens(t) for t in tokens]

    # Append to html output.
    for example, response in zip(examples, responses):
      outputs.append((example["image"], response))
      if num_examples and len(outputs) >= num_examples:
        return outputs

Настройте модель

Теперь, когда вы все настроили и ознакомились с обучающими данными, пришло время наконец настроить модель. Приведенный ниже код запускает цикл обучения модели на 64 шага и выводит скорость обучения ( lr в выходных данных) и коэффициент потерь для каждого шага.

Каждые 16 шагов модель выводит свои предсказания на этом этапе обучения. Этот код выводит предсказания для одного и того же набора изображений, чтобы вы могли видеть, как способность модели предсказывать описания улучшается со временем.

На ранних этапах обучения могут возникать проблемы с описаниями, например, повторяющиеся предложения, когда модель застревает в цикле прогнозирования, или незаконченные предложения. По мере обучения точность прогнозов модели неуклонно повышается. К 64-му шагу прогнозы модели должны максимально точно соответствовать описаниям, предоставленным обучающими данными.

На обработку T4 TPU этот процесс занимает около 15 минут.

# Run a short training loop with cosine learning rate schedule.
#
# Note: the first step can be quite slow on some machines (up to several minutes)
# due to XLA compilation of the jax.jit'd function.
#
%%time

BATCH_SIZE = 8
TRAIN_EXAMPLES = 512
LEARNING_RATE = 0.03

TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE
EVAL_STEPS = TRAIN_STEPS // 4

train_data_it = train_data_iterator()

sched_fn = big_vision.utils.create_learning_rate_schedule(
    total_steps=TRAIN_STEPS+1, base=LEARNING_RATE,
    decay_type="cosine", warmup_percent=0.10)

for step in range(1, TRAIN_STEPS+1):
  # Make list of N training examples.
  examples = [next(train_data_it) for _ in range(BATCH_SIZE)]

  # Convert list of examples into a dict of np.arrays and load onto devices.
  batch = jax.tree.map(lambda *x: np.stack(x), *examples)
  batch = big_vision.utils.reshard(batch, data_sharding)

  # Training step and report training loss
  learning_rate = sched_fn(step)
  params, loss = update_fn(params, batch, learning_rate)

  loss = jax.device_get(loss)
  print(f"step: {step:2d}/{TRAIN_STEPS:2d}   lr: {learning_rate:.5f}   loss: {loss:.4f}")

  if (step % EVAL_STEPS) == 0:
    print(f"Model predictions at step {step}")
    html_out = ""
    for image, caption in make_predictions(
        validation_data_iterator(), num_examples=4, batch_size=4):
      html_out += render_example(image, caption)
    display(HTML(html_out))
step:  1/64   lr: 0.00500   loss: 3.6567
step:  2/64   lr: 0.01000   loss: 1.9762
step:  3/64   lr: 0.01500   loss: 1.6299
step:  4/64   lr: 0.02000   loss: 1.5651
step:  5/64   lr: 0.02500   loss: 1.9813
step:  6/64   lr: 0.03000   loss: 1.9996
step:  7/64   lr: 0.02998   loss: 1.8595
step:  8/64   lr: 0.02992   loss: 1.6479
step:  9/64   lr: 0.02981   loss: 1.3693
step: 10/64   lr: 0.02966   loss: 1.3423
step: 11/64   lr: 0.02947   loss: 1.2122
step: 12/64   lr: 0.02924   loss: 1.0602
step: 13/64   lr: 0.02897   loss: 1.1314
step: 14/64   lr: 0.02866   loss: 1.2612
step: 15/64   lr: 0.02831   loss: 1.0132
step: 16/64   lr: 0.02792   loss: 1.2126
Model predictions at step 16
step: 17/64   lr: 0.02750   loss: 1.0986
step: 18/64   lr: 0.02704   loss: 0.9461
step: 19/64   lr: 0.02655   loss: 1.2098
step: 20/64   lr: 0.02602   loss: 1.0513
step: 21/64   lr: 0.02546   loss: 1.0979
step: 22/64   lr: 0.02488   loss: 0.9739
step: 23/64   lr: 0.02426   loss: 0.9589
step: 24/64   lr: 0.02362   loss: 0.7053
step: 25/64   lr: 0.02296   loss: 0.7347
step: 26/64   lr: 0.02227   loss: 0.6990
step: 27/64   lr: 0.02156   loss: 0.6736
step: 28/64   lr: 0.02083   loss: 0.6642
step: 29/64   lr: 0.02009   loss: 0.6908
step: 30/64   lr: 0.01933   loss: 0.7257
step: 31/64   lr: 0.01856   loss: 0.6902
step: 32/64   lr: 0.01778   loss: 0.7054
Model predictions at step 32
step: 33/64   lr: 0.01699   loss: 0.7709
step: 34/64   lr: 0.01620   loss: 0.6653
step: 35/64   lr: 0.01540   loss: 0.3811
step: 36/64   lr: 0.01460   loss: 0.3104
step: 37/64   lr: 0.01380   loss: 0.4042
step: 38/64   lr: 0.01301   loss: 0.3904
step: 39/64   lr: 0.01222   loss: 0.3339
step: 40/64   lr: 0.01144   loss: 0.4156
step: 41/64   lr: 0.01067   loss: 0.4085
step: 42/64   lr: 0.00991   loss: 0.3083
step: 43/64   lr: 0.00917   loss: 0.3757
step: 44/64   lr: 0.00844   loss: 0.3813
step: 45/64   lr: 0.00773   loss: 0.3381
step: 46/64   lr: 0.00704   loss: 0.2057
step: 47/64   lr: 0.00638   loss: 0.1287
step: 48/64   lr: 0.00574   loss: 0.1711
Model predictions at step 48
step: 49/64   lr: 0.00512   loss: 0.1183
step: 50/64   lr: 0.00454   loss: 0.1154
step: 51/64   lr: 0.00398   loss: 0.1967
step: 52/64   lr: 0.00345   loss: 0.1497
step: 53/64   lr: 0.00296   loss: 0.1688
step: 54/64   lr: 0.00250   loss: 0.1878
step: 55/64   lr: 0.00208   loss: 0.1865
step: 56/64   lr: 0.00169   loss: 0.1655
step: 57/64   lr: 0.00134   loss: 0.0911
step: 58/64   lr: 0.00103   loss: 0.1836
step: 59/64   lr: 0.00076   loss: 0.1242
step: 60/64   lr: 0.00053   loss: 0.0814
step: 61/64   lr: 0.00034   loss: 0.0866
step: 62/64   lr: 0.00019   loss: 0.1295
step: 63/64   lr: 0.00008   loss: 0.1053
step: 64/64   lr: 0.00002   loss: 0.0730
Model predictions at step 64
CPU times: user 2min 18s, sys: 8.98 s, total: 2min 27s
Wall time: 15min 45s

Выход

В этом блокноте для проверки используются всего 10 изображений. В обычном коде у вас, вероятно, было бы гораздо больше точек данных для проверки, но для этого блокнота выполните следующий код, чтобы сгенерировать описания для всех 10 изображений. После настройки модели эти описания должны быть очень похожи по форме и содержанию на описания, включенные в обучающие данные, которые вы рассматривали ранее в этом блокноте.

Выполните приведенный ниже код, чтобы сгенерировать описания для набора данных для проверки.

# The validation data consists of 10 images in a different domain than training
# data.
%%time

print("Model predictions")
html_out = ""
for image, caption in make_predictions(validation_data_iterator(), batch_size=4):
  html_out += render_example(image, caption)
display(HTML(html_out))
Model predictions
CPU times: user 1.87 s, sys: 283 ms, total: 2.15 s
Wall time: 39.3 s