Fine-tuning BLIP using PEFT
BLIP is a strong model for image captioning, with an architecture well suited to the task. We can fine-tune it so that it learns domain-specific captions.
PEFT
Hugging Face has a PEFT (Parameter-Efficient Fine-Tuning) library that lets us hook into an existing model and attach small trainable adapters to its Linear or Conv2D layers. We can use this to fine-tune BLIP through its attention and MLP layers while the original weights stay frozen.
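To see why this is cheap, here is a rough sketch of the parameter arithmetic for LoRA, the PEFT method used in this post. For a Linear layer whose weight has shape (out_features, in_features), LoRA leaves that weight frozen and trains two small matrices of shapes (r, in_features) and (out_features, r). The 768-wide layer is an assumption chosen for illustration, and `lora_param_count` is a hypothetical helper, not part of PEFT.

```python
def lora_param_count(in_features: int, out_features: int, r: int) -> int:
    """Trainable parameters LoRA adds to one Linear layer (matrices A and B)."""
    # A has shape (r, in_features), B has shape (out_features, r)
    return r * in_features + out_features * r


# Assumed layer size, for illustration only; r=16 matches the config used in this post.
in_features, out_features, r = 768, 768, 16

frozen = in_features * out_features  # original weight matrix, bias ignored
added = lora_param_count(in_features, out_features, r)

print(f"frozen weight params: {frozen}")
print(f"LoRA params added:    {added}")
print(f"ratio: {added / frozen:.2%}")
```

For this assumed layer the adapter is about 4% of the size of the frozen weight, which is why the overall trainable percentage stays small.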
Dependencies and environment
python -m venv venv
source ./venv/bin/activate # linux
call venv\Scripts\activate.bat  # windows (cmd)
Install our dependencies
pip install transformers peft datasets pillow
Install PyTorch
Use the PyTorch installation instructions for your machine: see "Get Started" on pytorch.org.
Setting up our PEFT and BLIP model
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor
from peft import LoraConfig, get_peft_model
config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    # target specific modules (look for Linear in the model)
    # print(model) to see the architecture of the model
    target_modules=[
        "self.query",
        "self.key",
        "self.value",
        "output.dense",
        "self_attn.qkv",
        "self_attn.projection",
        "mlp.fc1",
        "mlp.fc2",
    ],
)
# loading the model in float16
model = AutoModelForVision2Seq.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float16)
# Load the BLIP processor (image preprocessor plus tokenizer, so no additional model weights)
processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
# We layer our PEFT on top of our model using the PEFT config
model = get_peft_model(model, config)
model.print_trainable_parameters()
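The entries in target_modules are name fragments rather than full module paths. As far as we understand it, when target_modules is a list, PEFT selects a module if its dotted name equals an entry or ends with "." plus that entry. The sketch below is a simplified re-implementation of that rule for illustration, not PEFT's own code, and the module names are examples of what print(model) shows rather than an exhaustive list.

```python
def matches_target(module_name: str, target_modules: list[str]) -> bool:
    """Approximate the list-matching rule: exact name or dotted suffix."""
    return any(
        module_name == target or module_name.endswith("." + target)
        for target in target_modules
    )


target_modules = ["self.query", "output.dense", "self_attn.qkv", "mlp.fc1"]

# Illustrative module names; print(model) or model.named_modules() gives the real ones.
example_names = [
    "text_decoder.bert.encoder.layer.0.attention.self.query",
    "text_decoder.bert.encoder.layer.0.attention.output.dense",
    "text_decoder.bert.encoder.layer.0.intermediate.dense",
    "vision_model.encoder.layers.0.self_attn.qkv",
    "vision_model.encoder.layers.0.mlp.fc1",
]

for name in example_names:
    print(f"{matches_target(name, target_modules)!s:5} {name}")
```

Under this rule "output.dense" picks up the attention output projection but not "intermediate.dense", so any layer you want adapted needs its own fragment in the list.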
Load the dataset
Using the datasets library, we load a sample dataset, but you could load your own image/caption pairs here. See the datasets documentation on huggingface.co for more info.
from datasets import load_dataset
# Load only the train split
dataset = load_dataset("ybelkada/football-dataset", split="train")
# Print the dataset size and the first example
print("Dataset images: ", len(dataset))
print(next(iter(dataset)))
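For your own image/caption pairs, one option is the "imagefolder" loader in datasets, which reads a metadata.csv with a file_name column plus any extra columns. The sketch below writes such a file with a text column, so the loaded examples have the same image and text keys as the sample dataset. The `write_metadata` helper, the file names and the captions are all hypothetical, and a temporary folder stands in for your real image folder.

```python
import csv
import tempfile
from pathlib import Path


def write_metadata(image_dir: Path, captions: dict[str, str]) -> Path:
    """Write metadata.csv with `file_name` and `text` columns into image_dir."""
    image_dir.mkdir(parents=True, exist_ok=True)
    metadata_path = image_dir / "metadata.csv"
    with metadata_path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["file_name", "text"])
        for file_name, caption in captions.items():
            writer.writerow([file_name, caption])
    return metadata_path


# Hypothetical file names and captions; replace with your own images.
captions = {
    "0001.jpg": "a striker celebrating a goal",
    "0002.jpg": "a goalkeeper diving to save a penalty",
}

# A temporary folder stands in for the folder that holds your images.
image_dir = Path(tempfile.mkdtemp()) / "my_images"
metadata_path = write_metadata(image_dir, captions)
print(metadata_path.read_text(encoding="utf-8"))

# Once the image files sit next to metadata.csv, the folder can be loaded with:
# dataset = load_dataset("imagefolder", data_dir=str(image_dir), split="train")
```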
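Set up the data loader and collator
The FootballImageCaptioningDataset wraps the dataset and runs each image through the processor to get the
encoding to pass into the model. The DataLoader batches our dataset, and the collator turns each list of
examples into batch tensors: it stacks the image tensors and tokenizes the captions into padded input_ids and attention_mask.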
from torch.utils.data import DataLoader, Dataset
# Each dataset may use different column names and formats, so this class should be
# adjusted for each dataset. It produces `pixel_values` and `text` keys, which
# the collator expects
class FootballImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, idx):
        item = self.dataset[idx]
        encoding = self.processor(
            images=item["image"], padding="max_length", return_tensors="pt"
        )
        # remove batch dimension
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        encoding["text"] = item["text"]
        return encoding
def collator(batch):
    # stack the image tensors; tokenize and pad the captions
    processed_batch = {}
    for key in batch[0].keys():
        if key != "text":
            processed_batch[key] = torch.stack([example[key] for example in batch])
        else:
            text_inputs = processor.tokenizer(
                [example["text"] for example in batch], padding=True, return_tensors="pt"
            )
            processed_batch["input_ids"] = text_inputs["input_ids"]
            processed_batch["attention_mask"] = text_inputs["attention_mask"]
    return processed_batch
train_dataset = FootballImageCaptioningDataset(dataset, processor)
# Set batch_size to the batch that works for you.
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=2, collate_fn=collator)
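Captions in a batch usually tokenize to different lengths, which is why the collator calls the tokenizer with padding=True. The following is a minimal plain-Python sketch of what that padding amounts to; it is not the tokenizer's implementation. The token ids are made up, `pad_batch` is a hypothetical helper, and a pad id of 0 is an assumption.

```python
def pad_batch(token_id_lists: list[list[int]], pad_id: int = 0):
    """Right-pad each list to the longest length and build the matching attention mask."""
    max_len = max(len(ids) for ids in token_id_lists)
    input_ids, attention_mask = [], []
    for ids in token_id_lists:
        n_pad = max_len - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)
        # 1 marks a real token, 0 marks padding the model should ignore
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return input_ids, attention_mask


# Made-up token ids for two captions of different lengths.
batch = [[101, 7, 8, 9, 102], [101, 7, 102]]
input_ids, attention_mask = pad_batch(batch)
print(input_ids)       # [[101, 7, 8, 9, 102], [101, 7, 102, 0, 0]]
print(attention_mask)  # [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]]
```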
Optimizer
import torch.optim as optim
# Setup AdamW
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
Finish setting up for training
# Set the device which works for you or use these defaults
device = "cuda" if torch.cuda.is_available() else "cpu"
# Move the model to the same device as the inputs
model.to(device)
# Put the model in training mode (enables dropout); the LoRA parameters already require gradients
model.train()
The training loop
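Here we’ll train for 50 epochs; adjust this to suit your training dataset. The dataloader yields our batches, so for each batch we move the inputs onto the training device and pass them to the model. The caption token ids serve as both the input and the labels.
Whenever the step index within an epoch is a multiple of 10, we’ll test caption generation on the current training batch. With only 6 images and a batch size of 2 there are 3 steps per epoch, so in this run that happens once per epoch, on the first batch.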
# 50 is the number of epochs. adjust to how many you may need
for epoch in range(50):
    print("Epoch:", epoch)
    for idx, batch in enumerate(train_dataloader):
        input_ids = batch.pop("input_ids").to(device)
        pixel_values = batch.pop("pixel_values").to(device)
        attention_mask = batch.pop("attention_mask").to(device)
        outputs = model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            labels=input_ids,
            attention_mask=attention_mask,
        )
        loss = outputs.loss
        print("Loss:", loss.item())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if idx % 10 == 0:
            generated_output = model.generate(pixel_values=pixel_values)
            print(processor.batch_decode(generated_output, skip_special_tokens=True))
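With a batch size of 2 the per-step loss jumps around a lot, so a per-epoch summary can be easier to read. Here is a small sketch, assuming you append loss.item() to a list on every step and summarize at the end of each epoch. `summarize_epoch` is a hypothetical helper, and the numbers are the epoch 0 losses from the sample output.

```python
from statistics import mean


def summarize_epoch(epoch: int, step_losses: list[float]) -> str:
    """Format the mean, min and max of the per-step losses for one epoch."""
    return (
        f"Epoch {epoch}: mean loss {mean(step_losses):.4f} "
        f"(min {min(step_losses):.4f}, max {max(step_losses):.4f})"
    )


# Epoch 0 losses copied from the sample output.
epoch_0_losses = [5.096549034118652, 6.560091495513916, 5.695440292358398]
print(summarize_epoch(0, epoch_0_losses))
# Epoch 0: mean loss 5.7840 (min 5.0965, max 6.5601)
```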
Save
# Save the PEFT adapter to this directory
model.save_pretrained('./training/caption')
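To our knowledge, calling save_pretrained on a PEFT model writes the adapter weights together with an adapter_config.json that records the base model under base_model_name_or_path. The sketch below reads that field back as a quick sanity check on the saved directory. `read_adapter_base_model` is a hypothetical helper, and the exact set of files may differ between PEFT versions.

```python
import json
from pathlib import Path


def read_adapter_base_model(adapter_dir: str) -> str:
    """Return the base model id recorded in a saved PEFT adapter's config."""
    config_path = Path(adapter_dir) / "adapter_config.json"
    if not config_path.is_file():
        raise FileNotFoundError(f"No adapter_config.json in {adapter_dir}")
    with config_path.open(encoding="utf-8") as f:
        config = json.load(f)
    return config["base_model_name_or_path"]


# Only runs the check if the adapter has already been saved.
if Path("./training/caption").is_dir():
    print(read_adapter_base_model("./training/caption"))
```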
Sample output
trainable params: 5,455,872 || all params: 252,869,948 || trainable%: 2.1575802277619798
6
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=722x360 at 0x759E0338B350>, 'text': "Benzema after Real Mardid's win against PSG"}
Epoch: 0
Loss: 5.096549034118652
['ronaldo looking on during the match', "lionel, argentina's player, is seen during the match between argentina and argentina"]
Loss: 6.560091495513916
Loss: 5.695440292358398
Epoch: 1
Loss: 2.965883731842041
['ronaldo with portugal at the 2018 world cup', 'zida with france in 2006']
Loss: 4.138251304626465
Loss: 4.490686416625977
Epoch: 2
Loss: 2.501580238342285
['messi with argentina at the 2018 fifa world cup', 'maradona with argentina in 1986 world cup with argentina']
Loss: 3.6341114044189453
Loss: 3.1389389038085938
...
Epoch: 49
Loss: 1.3841571807861328
['ronaldo with portugal at the 2018 world cup', 'zidane with france in 2006 world cup']
Loss: 1.6498689651489258
Loss: 2.872819423675537
Saving to ./training/caption
Inference
Then we can load the saved model for inference. Note that we use an image from the web, so download it into the current directory first.
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor
# The base model id is only used here to load the matching processor
model_id = "Salesforce/blip-image-captioning-base"
processor = AutoProcessor.from_pretrained(model_id)
# The saved PEFT adapter records which base model it belongs to, so we can load
# the adapter directory directly (peft must be installed)
model = AutoModelForVision2Seq.from_pretrained('./training/caption')
# Switch to evaluation mode (disables dropout)
model.eval()
# https://upload.wikimedia.org/wikipedia/commons/5/52/Lionel_Messi_playing_in_Argentina_2022_FIFA_World_Cup.jpg
# Set the images to be captioned
images = [Image.open("Lionel_Messi_playing_in_Argentina_2022_FIFA_World_Cup.jpg")]
processed = processor(images=images, padding="max_length", return_tensors="pt")
generated_output = model.generate(pixel_values=processed['pixel_values'], max_new_tokens=64)
print(processor.batch_decode(generated_output, skip_special_tokens=True))
$ python inference2.py
['messi with argentina at the 2018 world cup']
Full script
Available as a full script for the training: