This notebook shows how to fine-tune PaliGemma on a vision-language task with JAX. Fine-tuning is a process that can improve your model's performance on specific tasks or help the model adhere to specific output requirements when instructions aren't sufficient and you have a set of examples that demonstrate the outputs you want. Gemma-based models like PaliGemma require fine-tuning to produce expected results.
What's in this notebook
This notebook uses the model reference implementation from big_vision
and shows how to:
- Install dependencies, and download the PaliGemma model checkpoint and training data
- Load the model onto GPU devices
- Prepare the model's inputs for training and inference
- Fine-tune the model
- Inspect the output
The training data for this notebook consists of 90 pairs of images and long captions describing them. To make it runnable on a T4 colab runtime, you'll only fine-tune the attention layers of the language model and freeze the other parameters.
This example is for learning purposes only. In a real use case, the amount of data, trainable parameters, training steps and hyper-parameters, and obtained results could be significantly different.
Before you begin
Before going through this notebook, you should be familiar with Python code, as well as how large language models (LLMs) are trained. You don't need to be familiar with JAX, but basic knowledge about JAX (or similar technologies such as Keras) is helpful when reading through the example code.
Setup
The following sections explain the preliminary steps for getting a notebook to use a PaliGemma model, including model access, getting an API key, and configuring the notebook runtime.
Get access to PaliGemma
Before using PaliGemma for the first time, you must request access to the model through Kaggle by completing the following steps:
- Log in to Kaggle, or create a new Kaggle account if you don't already have one.
- Go to the PaliGemma model card and click Request Access.
- Complete the consent form and accept the terms and conditions.
Configure your API key
To use PaliGemma, you must provide your Kaggle username and a Kaggle API key.
To generate a Kaggle API key, open your Settings page in Kaggle and click Create New Token. This triggers the download of a kaggle.json file containing your API credentials.
Then, in Colab, select Secrets (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name KAGGLE_USERNAME and your API key under the name KAGGLE_KEY.
Select the runtime
To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the PaliGemma model. In this case, you can use a T4 GPU:
- In the upper-right of the Colab window, click the ▾ (Additional connection options) dropdown menu.
- Select Change runtime type.
- Under Hardware accelerator, select T4 GPU.
Set environment variables
Set the environment variables for KAGGLE_USERNAME and KAGGLE_KEY.
import os
from google.colab import userdata
# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate or make your credentials available in ~/.kaggle/kaggle.json
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
# The T4 runtime is tight on memory to finetune this model. Preallocate
# all memory ahead of time to avoid out-of-memory due to fragmentation.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"
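The comment in the cell above notes that, outside Colab, credentials can be made available through ~/.kaggle/kaggle.json. A minimal sketch of copying that file's contents into the same environment variables follows. The helper name `load_kaggle_credentials` is hypothetical, and the sketch assumes the file holds `username` and `key` fields.

```python
import json
import os
from pathlib import Path


def load_kaggle_credentials(path="~/.kaggle/kaggle.json", environ=None):
  """Copies Kaggle credentials from a kaggle.json file into env vars.

  Hypothetical helper: assumes the file holds "username" and "key" fields.
  """
  environ = os.environ if environ is None else environ
  creds = json.loads(Path(path).expanduser().read_text())
  environ["KAGGLE_USERNAME"] = creds["username"]
  environ["KAGGLE_KEY"] = creds["key"]
  return environ
```

Calling `load_kaggle_credentials()` with no arguments writes to `os.environ`; passing a plain dict as `environ` makes the helper easy to try without touching the real environment.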
Fetch the big_vision repository and install related dependencies
Download the big_vision repository to your Colab notebook from GitHub and install dependencies related to big_vision by running the following code.
import os
import sys

# Colab runtimes with remote TPUs are not supported.
if "COLAB_TPU_ADDR" in os.environ:
  raise RuntimeError("It seems you are using Colab with remote TPUs which is not supported.")

# Fetch big_vision repository if python doesn't know about it and install
# dependencies needed for this notebook.
if not os.path.exists("big_vision_repo"):
  !git clone --quiet --branch=main --depth=1 \
     https://github.com/google-research/big_vision big_vision_repo

# Append big_vision code to python import path
if "big_vision_repo" not in sys.path:
  sys.path.append("big_vision_repo")

# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"
Import JAX and other dependencies
Import JAX and other dependencies required for PaliGemma, like TensorFlow and NumPy.
import base64
import functools
import html
import io
import os
import warnings
import jax
import jax.numpy as jnp
import numpy as np
import ml_collections
import tensorflow as tf
import sentencepiece
from IPython.display import display, HTML
from PIL import Image
# Import model definition from big_vision
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns
# Import big vision utilities
import big_vision.datasets.jsonl
import big_vision.utils
import big_vision.sharding
# Don't let TF use the GPU or TPUs
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")
backend = jax.extend.backend.get_backend()
print(f"JAX version: {jax.__version__}")
print(f"JAX platform: {backend.platform}")
print(f"JAX devices: {jax.device_count()}")
Download and configure the model
In this step, you'll download the model checkpoint and configure it so that you can fine-tune it later on. This step shows you how to move model parameters into accelerator (GPU/TPU) memory, which is useful for fine-tuning models on devices with limited resources.
Download the model checkpoint
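PaliGemma includes several model variations. For this tutorial, you'll use the base JAX/Flax PaliGemma 2 3B checkpoint at 224x224 resolution. The code also includes commented-out settings for the original PaliGemma 3B checkpoint.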
Download the model checkpoint from Kaggle by running the following code. This process takes several minutes to complete.
import os
import kagglehub

# Use these for PaliGemma-2 3B 224px²
LLM_VARIANT = "gemma2_2b"
MODEL_PATH = "./paligemma2-3b-pt-224.b16.npz"
KAGGLE_HANDLE = "google/paligemma-2/jax/paligemma2-3b-pt-224"  # Path to fetch from Kaggle.

# Use these for PaliGemma 1:
# LLM_VARIANT = "gemma_2b"
# MODEL_PATH = "./paligemma-3b-pt-224.f16.npz"
# KAGGLE_HANDLE = "google/paligemma/jax/paligemma-3b-pt-224"

if not os.path.exists(MODEL_PATH):
  print("Downloading the checkpoint from Kaggle, this could take a few minutes....")
  MODEL_PATH = kagglehub.model_download(KAGGLE_HANDLE, MODEL_PATH)
  print(f"Model path: {MODEL_PATH}")

TOKENIZER_PATH = "./paligemma_tokenizer.model"
if not os.path.exists(TOKENIZER_PATH):
  print("Downloading the model tokenizer...")
  !gsutil cp gs://big_vision/paligemma_tokenizer.model {TOKENIZER_PATH}
  print(f"Tokenizer path: {TOKENIZER_PATH}")

DATA_DIR="./longcap100"
if not os.path.exists(DATA_DIR):
  print("Downloading the dataset...")
  !gsutil -m -q cp -n -r gs://longcap100/ .
  print(f"Data path: {DATA_DIR}")
Configure the model
It's time to actually start configuring the model that you're going to use.
For this notebook, you need to be able to fit your model onto a T4 GPU. With limited memory available, you have to be mindful of how your model is configured.
If you fine-tune every parameter, your model won't be able to run in the notebook environment. As a result, in this part of the notebook, you'll configure your model so that some of the parameters are frozen, and only fine-tune the parameters that really need to be fine-tuned for the model to give you accurate results. In LLMs, parameters are said to be frozen when they are kept fixed and no longer updated during training.
In order to configure your model, you need to:
- Initialize the model_config as a FrozenConfigDict (an immutable configuration dictionary) that holds the model settings
- Initialize an instance of the PaliGemma Model class using the model_config as its configuration
- Load the model parameters into RAM
- Define a decode function to sample outputs from the model
The code in this cell takes about a minute to run to completion.
# Define model
# IMPORTANT: Gemma-2 has a "final_logits_softcap" property, we set it to 0.0
# for better transfer results.
model_config = ml_collections.FrozenConfigDict({
    "llm": {"vocab_size": 257_152, "variant": LLM_VARIANT, "final_logits_softcap": 0.0},
    "img": {"variant": "So400m/14", "pool_type": "none", "scan": True, "dtype_mm": "float16"}
})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)
# Load params - this can take up to 1 minute in T4 colabs.
params = paligemma.load(None, MODEL_PATH, model_config)
# Define `decode` function to sample outputs from the model.
decode_fn = predict_fns.get_all(model)['decode']
decode = functools.partial(decode_fn, devices=jax.devices(), eos_token=tokenizer.eos_id())
Move model parameters into GPU/TPU memory
Now you need to move the model parameters into GPU/TPU memory. First, shard the parameters across the available GPUs, then load the parameters. Here, you'll load the parameters sequentially. This takes longer than loading them all at once, but loading them simultaneously requires more RAM than you have available in this notebook.
Finally, print out all of the parameters to see what type each individual parameter is cast to. Frozen parameters are kept as float16, while the trainable parameters are cast to float32. When you inspect the list, you'll see that most of the parameters have been frozen and are float16.
# Create a pytree mask of the trainable params.
def is_trainable_param(name, param):  # pylint: disable=unused-argument
  if name.startswith("llm/layers/attn/"):  return True
  if name.startswith("llm/"):              return False
  if name.startswith("img/"):              return False
  raise ValueError(f"Unexpected param name {name}")
trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)

# If more than one device is available (e.g. multiple GPUs) the parameters can
# be sharded across them to reduce HBM usage per device.
mesh = jax.sharding.Mesh(jax.devices(), ("data",))

data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("data"))

params_sharding = big_vision.sharding.infer_sharding(
    params, strategy=[('.*', 'fsdp(axis="data")')], mesh=mesh)

# Yes: Some donated buffers are not usable.
warnings.filterwarnings(
    "ignore", message="Some donated buffers were not usable")

@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))
def maybe_cast_to_f32(params, trainable):
  # Cast others to float16, since some GPUs don't support bf16.
  return jax.tree.map(lambda p, m: p.astype(jnp.float32)
                      if m else p.astype(jnp.float16),
                      params, trainable)

# Loading all params simultaneously - albeit much faster and more succinct -
# requires more RAM than the T4 colab runtimes have by default.
# Instead we do it param by param.
params, treedef = jax.tree.flatten(params)
sharding_leaves = jax.tree.leaves(params_sharding)
trainable_leaves = jax.tree.leaves(trainable_mask)
for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):
  params[idx] = big_vision.utils.reshard(params[idx], sharding)
  params[idx] = maybe_cast_to_f32(params[idx], trainable)
  params[idx].block_until_ready()
params = jax.tree.unflatten(treedef, params)

# Print params to show what the model is made of.
def parameter_overview(params):
  for path, arr in big_vision.utils.tree_flatten_with_names(params)[0]:
    print(f"{path:80s} {str(arr.shape):22s} {arr.dtype}")

print(" == Model params == ")
parameter_overview(params)
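To make the effect of the name-based rule concrete, the sketch below applies the same prefix test to a handful of made-up parameter names and groups them by the dtype they would be cast to. The names are illustrative assumptions, not the checkpoint's actual parameter names, and `is_trainable` is a simplified stand-in for is_trainable_param.

```python
def is_trainable(name):
  """Same prefix rule as is_trainable_param, applied to the name only."""
  if name.startswith("llm/layers/attn/"):
    return True
  if name.startswith(("llm/", "img/")):
    return False
  raise ValueError(f"Unexpected param name {name}")


# Made-up names for illustration only.
names = [
    "img/embedding/kernel",
    "llm/embedder/input_embedding",
    "llm/layers/attn/q_einsum/w",
    "llm/layers/attn/kv_einsum/w",
    "llm/layers/mlp/linear",
]

# Trainable params are cast to float32, frozen ones to float16.
dtypes = {name: "float32" if is_trainable(name) else "float16" for name in names}
trainable = [name for name, dtype in dtypes.items() if dtype == "float32"]

for name, dtype in dtypes.items():
  print(f"{name:35s} {dtype}")
print(f"{len(trainable)} of {len(names)} params trainable")
```

Only the names under llm/layers/attn/ end up as float32. Everything else stays frozen in float16.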
Prepare to tune the model
Now that your model is configured, you can tune it. In this step, you'll create your model's inputs as well as the training and validation iterators, view the training examples, and define the training and validation loops.
Create model inputs
The model checkpoint you're using has already been trained on images of various aspect ratios that have been resized to 224x224 pixels, and to handle tokenized texts.
The code below defines three functions that you'll use in the next step to create the model's inputs:
- preprocess_image: Normalizes the image data. In this case, pre-processing expands greyscale images to three channels, removes the alpha layer, resizes the passed-in image to the size required by the model for image inputs (224x224 pixels), and rescales pixel values to the range [-1, 1].
- preprocess_tokens: Tokenizes the text and adds flags to mark whether a token is a prefix or suffix token. These flags will be used later on in the code, during the training step and the evaluation loop.
- postprocess_tokens: Removes any tokens left at and/or after the end-of-sequence (EOS) token and returns the remaining decoded tokens.
def preprocess_image(image, size=224):
  # Model has been trained to handle images of different aspects ratios
  # resized to 224x224 in the range [-1, 1]. Bilinear and antialias resize
  # options are helpful to improve quality in some tasks.
  image = np.asarray(image)
  if image.ndim == 2:  # Greyscale image without a channel axis: replicate into 3 channels.
    image = np.stack((image,)*3, axis=-1)
  image = image[..., :3]  # Remove alpha layer.
  assert image.shape[-1] == 3

  image = tf.constant(image)
  image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)
  return image.numpy() / 127.5 - 1.0  # [0, 255]->[-1,1]

def preprocess_tokens(prefix, suffix=None, seqlen=None):
  # Model has been trained to handle tokenized text composed of a prefix with
  # full attention and a suffix with causal attention.
  separator = "\n"
  tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode(separator)
  mask_ar = [0] * len(tokens)    # 0 to use full attention for prefix.
  mask_loss = [0] * len(tokens)  # 0 to not use prefix tokens in the loss.

  if suffix:
    suffix = tokenizer.encode(suffix, add_eos=True)
    tokens += suffix
    mask_ar += [1] * len(suffix)    # 1 to use causal attention for suffix.
    mask_loss += [1] * len(suffix)  # 1 to use suffix tokens in the loss.

  mask_input = [1] * len(tokens)    # 1 if it's a token, 0 if padding.
  if seqlen:
    padding = [0] * max(0, seqlen - len(tokens))
    tokens = tokens[:seqlen] + padding
    mask_ar = mask_ar[:seqlen] + padding
    mask_loss = mask_loss[:seqlen] + padding
    mask_input = mask_input[:seqlen] + padding

  return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))

def postprocess_tokens(tokens):
  tokens = tokens.tolist()  # np.array to list[int]
  try:  # Remove tokens at and after EOS if any.
    eos_pos = tokens.index(tokenizer.eos_id())
    tokens = tokens[:eos_pos]
  except ValueError:
    pass
  return tokenizer.decode(tokens)
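The mask layout produced by preprocess_tokens is easier to see on a toy input. The sketch below mirrors its logic with a stand-in whitespace tokenizer. `toy_encode`, `toy_preprocess_tokens`, and the special token ids are assumptions for illustration; the notebook itself uses the SentencePiece tokenizer.

```python
BOS, EOS, SEP = 1, 2, 3  # Made-up special token ids.
_vocab = {}


def toy_encode(text):
  """Stand-in tokenizer: one id (starting at 10) per whitespace-split word."""
  return [_vocab.setdefault(w, 10 + len(_vocab)) for w in text.split()]


def toy_preprocess_tokens(prefix, suffix=None, seqlen=None):
  tokens = [BOS] + toy_encode(prefix) + [SEP]
  mask_ar = [0] * len(tokens)    # Full attention over the prefix.
  mask_loss = [0] * len(tokens)  # Prefix tokens are excluded from the loss.
  if suffix:
    suffix_ids = toy_encode(suffix) + [EOS]
    tokens += suffix_ids
    mask_ar += [1] * len(suffix_ids)    # Causal attention over the suffix.
    mask_loss += [1] * len(suffix_ids)  # Suffix tokens count in the loss.
  mask_input = [1] * len(tokens)  # 1 for real tokens, 0 for padding.
  if seqlen:
    padding = [0] * max(0, seqlen - len(tokens))
    tokens = tokens[:seqlen] + padding
    mask_ar = mask_ar[:seqlen] + padding
    mask_loss = mask_loss[:seqlen] + padding
    mask_input = mask_input[:seqlen] + padding
  return tokens, mask_ar, mask_loss, mask_input


tokens, mask_ar, mask_loss, mask_input = toy_preprocess_tokens(
    "caption en", "a red car", seqlen=10)
print(tokens)      # [1, 10, 11, 3, 12, 13, 14, 2, 0, 0]
print(mask_ar)     # [0, 0, 0, 0, 1, 1, 1, 1, 0, 0]
print(mask_loss)   # [0, 0, 0, 0, 1, 1, 1, 1, 0, 0]
print(mask_input)  # [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
```

The first four positions (BOS, the two prefix words, and the separator) get full attention and no loss, the suffix plus EOS get causal attention and contribute to the loss, and the last two positions are padding.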
Create the training and validation iterators
Create two iterators:
- A training iterator that lets the training process go through the data in chunks rather than processing it all at once. This also lets you pre-process each example before use.
- A validation iterator that lets the training process iterate over the validation dataset to see how well the tuned model's output aligns with the provided results.
SEQLEN = 128

train_dataset = big_vision.datasets.jsonl.DataSource(
    os.path.join(DATA_DIR, "data_train90.jsonl"),
    fopen_keys={"image": DATA_DIR})

val_dataset = big_vision.datasets.jsonl.DataSource(
    os.path.join(DATA_DIR, "data_val10.jsonl"),
    fopen_keys={"image": DATA_DIR})


def train_data_iterator():
  """Never ending iterator over training examples."""
  # Shuffle examples and repeat so one can train for many epochs.
  dataset = train_dataset.get_tfdata().shuffle(1_000).repeat()
  for example in dataset.as_numpy_iterator():
    image = Image.open(io.BytesIO(example["image"]))
    image = preprocess_image(image)

    prefix = "caption en"  # Could also be a different prefix per example.
    suffix = example["suffix"].decode().lower()
    tokens, mask_ar, mask_loss, _ = preprocess_tokens(prefix, suffix, SEQLEN)

    yield {
        "image": np.asarray(image),
        "text": np.asarray(tokens),
        "mask_ar": np.asarray(mask_ar),
        "mask_loss": np.asarray(mask_loss),
    }


def validation_data_iterator():
  """Single iterator over validation examples."""
  for example in val_dataset.get_tfdata(ordered=True).as_numpy_iterator():
    image = Image.open(io.BytesIO(example["image"]))
    image = preprocess_image(image)

    prefix = "caption en"  # Could also be a different prefix per example.
    tokens, mask_ar, _, mask_input = preprocess_tokens(prefix, seqlen=SEQLEN)

    yield {
        "image": np.asarray(image),
        "text": np.asarray(tokens),
        "mask_ar": np.asarray(mask_ar),
        "mask_input": np.asarray(mask_input),
    }
View training examples
In this notebook, the training data contains 90 images that are paired with long descriptions of what's depicted in the image.
The code below prints a random selection of images with their descriptions from the training data set so that you can see what the images and descriptions your model is trained on look like. Each image is displayed at 128x128 pixels, with its description printed to the right of the image.
def render_inline(image, resize=(128, 128)):
  """Convert image into inline html."""
  image = Image.fromarray(image)
  image = image.resize(resize)  # resize() returns a new image; keep the result.
  with io.BytesIO() as buffer:
    image.save(buffer, format='jpeg')
    image_b64 = str(base64.b64encode(buffer.getvalue()), "utf-8")
    return f"data:image/jpeg;base64,{image_b64}"

def render_example(image, caption):
  image = ((image + 1)/2 * 255).astype(np.uint8)  # [-1,1] -> [0, 255]
  return f"""
    <div style="display: inline-flex; align-items: center; justify-content: center;">
        <img style="width:128px; height:128px;" src="{render_inline(image, resize=(64,64))}" />
        <p style="width:256px; margin:10px; font-size:small;">{html.escape(caption)}</p>
    </div>
    """

html_out = ""
for idx, example in zip(range(8), train_data_iterator()):
  caption = postprocess_tokens(example["text"])  # detokenize model input.
  caption = caption[len("caption en\n"):]        # strip prefix
  html_out += render_example(example["image"], caption)

print("Training examples")
display(HTML(html_out))
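As a quick check on the two value ranges used above, the following sketch reproduces the arithmetic of preprocess_image ([0, 255] to [-1, 1]) and its inverse in render_example on plain Python numbers. The function names are illustrative only and are not part of the notebook.

```python
def to_model_range(pixel):
  """Maps a pixel value in [0, 255] to [-1, 1], as in preprocess_image."""
  return pixel / 127.5 - 1.0


def to_uint8_range(value):
  """Maps a value in [-1, 1] back to [0, 255], as in render_example."""
  return int((value + 1) / 2 * 255)  # int() truncates, like astype(np.uint8).


for pixel in (0, 51, 255):
  value = to_model_range(pixel)
  print(pixel, value, to_uint8_range(value))
```

Because the conversion back truncates rather than rounds, a round trip can land one intensity level below the original value for some pixels, which is harmless for display purposes.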
Define the training and evaluation loops
Define the training loop to train the model on the provided dataset, and the evaluation loop to look at all of the examples in the validation dataset and make its predictions.
Defining the training loop
The update_fn function defines the training step. During the training step, the loss per example is calculated and stochastic gradient descent (SGD) is applied to the trainable parameters.
Recall that earlier in the notebook, the preprocess_tokens function returned several masks, including mask_loss. You'll use mask_loss here to exclude prefix and padded tokens from the loss. Without it, the loss calculation will be skewed. You also need to normalize the loss of each example by its number of tokens, since each example has a different number of tokens. The batch loss is then the mean of the per-example losses.
The training step then applies an SGD update to the trainable parameters and leaves the frozen parameters unchanged.
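As a worked illustration of the masked, length-normalized loss, the sketch below computes it for two toy examples in plain Python. The per-token log-probabilities are made-up numbers, and `masked_example_loss` is a hypothetical helper rather than part of the notebook.

```python
def masked_example_loss(token_logp, mask_loss):
  """Negative mean log-probability over the tokens where mask_loss is 1."""
  total = -sum(lp * m for lp, m in zip(token_logp, mask_loss))
  return total / max(sum(mask_loss), 1)  # Avoid dividing by zero.


# Made-up log-probabilities of the target token at each position.
batch_logp = [
    [-0.5, -1.0, -2.0, -3.0],  # 1 prefix token, 3 suffix tokens.
    [-0.5, -4.0, -9.0, -9.0],  # 1 prefix token, 1 suffix token, 2 padding.
]
batch_mask = [
    [0, 1, 1, 1],
    [0, 1, 0, 0],
]

example_losses = [masked_example_loss(lp, m)
                  for lp, m in zip(batch_logp, batch_mask)]
batch_loss = sum(example_losses) / len(example_losses)
print(example_losses)  # [2.0, 4.0]
print(batch_loss)      # 3.0
```

Without the mask, the padding positions of the second example (-9.0 each) would dominate its loss. Without the per-example normalization, the first example would contribute 6.0 simply because it has more suffix tokens.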
Defining the evaluation loop
The make_predictions function is your evaluation loop. The evaluation loop is fairly straightforward, with one notable change. The validation data set is very small (just 10 examples), so the last batch may not contain enough examples to fill the batch size. This means that in the evaluation loop, you need to pad the batch by repeating examples.
To make sure that your evaluation loop only counts actual examples and not the padded examples, you have to apply a mask to the padded examples that excludes them from the output.
# The main update_fn using a simple stochastic gradient descent (SGD).
@functools.partial(jax.jit, donate_argnums=(0,))
def update_fn(params, batch, learning_rate):
  imgs, txts, mask_ar = batch["image"], batch["text"], batch["mask_ar"]

  def loss_fn(params):
    text_logits, _ = model.apply({"params": params}, imgs, txts[:, :-1], mask_ar[:, :-1], train=True)
    logp = jax.nn.log_softmax(text_logits, axis=-1)

    # The model takes as input txts[:, :-1] but the loss is defined as predicting
    # next tokens txts[:, 1:]. Additionally, mask_loss[:, 1:] indicates which tokens
    # are part of the loss (e.g. prefix and padded tokens are not included).
    mask_loss = batch["mask_loss"][:, 1:]
    targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])

    # Compute the loss per example. i.e. the mean of per token pplx.
    # Since each example has a different number of tokens we normalize it.
    token_pplx = jnp.sum(logp * targets, axis=-1)  # sum across vocab_size.
    example_loss = -jnp.sum(token_pplx * mask_loss, axis=-1)  # sum across seq_len.
    example_loss /= jnp.clip(jnp.sum(mask_loss, -1), 1)  # weight by num of tokens.

    # batch_loss: mean of per example loss.
    return jnp.mean(example_loss)

  loss, grads = jax.value_and_grad(loss_fn)(params)

  # Apply gradients to trainable params using SGD.
  def apply_grad(param, gradient, trainable):
    if not trainable: return param
    return param - learning_rate * gradient

  params = jax.tree_util.tree_map(apply_grad, params, grads, trainable_mask)

  return params, loss

# Evaluation/inference loop.
def make_predictions(data_iterator, *, num_examples=None,
                     batch_size=4, seqlen=SEQLEN, sampler="greedy"):
  outputs = []
  while True:
    # Construct a list of examples in the batch.
    examples = []
    try:
      for _ in range(batch_size):
        examples.append(next(data_iterator))
        examples[-1]["_mask"] = np.array(True)  # Indicates true example.
    except StopIteration:
      if len(examples) == 0:
        return outputs

    # Not enough examples to complete a batch. Pad by repeating last example.
    while len(examples) % batch_size:
      examples.append(dict(examples[-1]))
      examples[-1]["_mask"] = np.array(False)  # Indicates padding example.

    # Convert list of examples into a dict of np.arrays and load onto devices.
    batch = jax.tree.map(lambda *x: np.stack(x), *examples)
    batch = big_vision.utils.reshard(batch, data_sharding)

    # Make model predictions
    tokens = decode({"params": params}, batch=batch,
                    max_decode_len=seqlen, sampler=sampler)

    # Fetch model predictions to host and detokenize.
    tokens, mask = jax.device_get((tokens, batch["_mask"]))
    tokens = tokens[mask]  # remove padding examples.
    responses = [postprocess_tokens(t) for t in tokens]

    # Append to html output.
    for example, response in zip(examples, responses):
      outputs.append((example["image"], response))
      if num_examples and len(outputs) >= num_examples:
        return outputs
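The padding-and-mask idea used by make_predictions can be sketched in isolation with plain Python lists. `pad_batch` is a hypothetical helper for illustration, and the string items stand in for real examples and predictions.

```python
def pad_batch(examples, batch_size):
  """Pads a partial batch by repeating the last example, with a validity mask."""
  if not examples:
    raise ValueError("Cannot pad an empty batch.")
  batch = list(examples)
  mask = [True] * len(batch)
  while len(batch) % batch_size:
    batch.append(batch[-1])
    mask.append(False)  # Marks a padding example.
  return batch, mask


# 10 validation examples with batch_size=4 leave a final batch of 2.
last_batch, mask = pad_batch(["img8", "img9"], batch_size=4)
print(last_batch)  # ['img8', 'img9', 'img9', 'img9']
print(mask)        # [True, True, False, False]

# Keep only predictions that belong to real examples.
predictions = [f"caption for {x}" for x in last_batch]
kept = [p for p, m in zip(predictions, mask) if m]
print(kept)        # ['caption for img8', 'caption for img9']
```

Padding keeps every batch the same shape, and the mask guarantees that repeated examples never appear in the results.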
Tune the model
Now that you've set everything up and taken a look at the training data, it's time to finally tune the model. The code below runs the training loop for the model for 64 steps and prints the learning rate (lr in the printed output) and the loss for each step.
Every 16 steps, the model prints what its predictions are at that step in the training. This code prints out predictions for the same set of images so that you can see the model's ability to predict descriptions improve over time.
At earlier steps in the training, there are likely to be issues with the descriptions, such as unfinished sentences or repeated sentences as the model gets stuck in a predictive loop. The model's predictions become steadily more accurate as training progresses. By step 64, the model's predictions should closely resemble the descriptions provided by the training data.
This process takes around 15 minutes to complete on a T4 GPU.
# Run a short training loop with cosine learning rate schedule.
#
# Note: the first step can be quite slow on some machines (up to several minutes)
# due to XLA compilation of the jax.jit'd function.
#
%%time

BATCH_SIZE = 8
TRAIN_EXAMPLES = 512
LEARNING_RATE = 0.03

TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE
EVAL_STEPS = TRAIN_STEPS // 4

train_data_it = train_data_iterator()

sched_fn = big_vision.utils.create_learning_rate_schedule(
    total_steps=TRAIN_STEPS+1, base=LEARNING_RATE,
    decay_type="cosine", warmup_percent=0.10)

for step in range(1, TRAIN_STEPS+1):
  # Make list of N training examples.
  examples = [next(train_data_it) for _ in range(BATCH_SIZE)]

  # Convert list of examples into a dict of np.arrays and load onto devices.
  batch = jax.tree.map(lambda *x: np.stack(x), *examples)
  batch = big_vision.utils.reshard(batch, data_sharding)

  # Training step and report training loss
  learning_rate = sched_fn(step)
  params, loss = update_fn(params, batch, learning_rate)

  loss = jax.device_get(loss)
  print(f"step: {step:2d}/{TRAIN_STEPS:2d}   lr: {learning_rate:.5f}   loss: {loss:.4f}")

  if (step % EVAL_STEPS) == 0:
    print(f"Model predictions at step {step}")
    html_out = ""
    for image, caption in make_predictions(
        validation_data_iterator(), num_examples=4, batch_size=4):
      html_out += render_example(image, caption)
    display(HTML(html_out))
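For intuition about the lr values printed during training, here is a minimal sketch of a cosine schedule with linear warmup. It is a common textbook formulation and an assumption for illustration, not the implementation in big_vision.utils.create_learning_rate_schedule, so its exact values may differ from the ones the notebook prints.

```python
import math


def cosine_schedule(step, total_steps, base, warmup_percent):
  """Linear warmup followed by cosine decay to zero (a common formulation)."""
  warmup_steps = int(total_steps * warmup_percent)
  if warmup_steps and step < warmup_steps:
    return base * step / warmup_steps
  progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
  progress = min(max(progress, 0.0), 1.0)
  return base * 0.5 * (1.0 + math.cos(math.pi * progress))


TOTAL_STEPS = 65  # Matches TRAIN_STEPS + 1 in the training cell.
for step in (1, 6, 16, 32, 48, 64):
  lr = cosine_schedule(step, TOTAL_STEPS, base=0.03, warmup_percent=0.10)
  print(f"step {step:2d}: lr={lr:.5f}")
```

In this formulation the learning rate climbs linearly to its peak over the first 10% of steps and then decays smoothly toward zero, which matches the general shape you should see in the training log.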
Output
The validation data for this notebook consists of just 10 images. In a real use case, you would likely have many more data points for validation, but for this notebook, 10 images are enough to spot-check the tuned model. After tuning the model, the generated descriptions should be very similar in form and content coverage to the descriptions included with the training data that you looked at earlier in this notebook.
Run the code below to generate descriptions for all 10 images in the validation data set.
# The validation data consists of 10 images in a different domain than training
# data.
%%time

print("Model predictions")
html_out = ""
for image, caption in make_predictions(validation_data_iterator(), batch_size=4):
  html_out += render_example(image, caption)
display(HTML(html_out))