Classify an incomplete vision–language dataset (Oxford‑IIIT Pets) with deep learning

This tutorial demonstrates how to classify samples from an incomplete vision–language dataset using the iMML library. iMML supports robust classification even when some modalities (e.g., text or image) are missing, making it suitable for real‑world multi‑modal data where missingness is common.

We will use the RAGPT algorithm from the iMML classify module on the Oxford‑IIIT Pets dataset and evaluate its performance.

What you will learn:

  • How to load a public vision–language dataset (Oxford‑IIIT Pets via Hugging Face Datasets).

  • How to adapt this workflow to your own vision–language data.

  • How to build a retrieval‑augmented memory bank and prompts with MCR.

  • How to train the RAGPT classifier when image or text may be missing.

  • How to track metrics during training and evaluate with MCC and a confusion matrix.

# License: BSD 3-Clause License

Step 0: Prerequisites

To run this tutorial, install the extras for deep learning:

pip install imml[deep]

We also use the Hugging Face Datasets library to load Oxford‑IIIT Pets:

pip install datasets

Step 1: Import required libraries

import shutil
from PIL import Image
from lightning import Trainer
import lightning as L
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
import torch
import os
import pandas as pd
from sklearn.metrics import matthews_corrcoef, ConfusionMatrixDisplay
from sklearn.preprocessing import LabelEncoder
from datasets import load_dataset

from imml.ampute import Amputer
from imml.classify import RAGPT
from imml.load import RAGPTDataset, RAGPTCollator
from imml.retrieve import MCR

Step 2: Prepare the dataset

We use the oxford-iiit-pet-vl-enriched dataset, a public vision–language dataset with images and captions available on Hugging Face Datasets as visual-layer/oxford-iiit-pet-vl-enriched. For retrieval, we will use the MCR class from the retrieve module.

random_state = 42
L.seed_everything(random_state)

# Local working directory (images will be saved here so MCR can read paths)
data_folder = "oxford_iiit_pet"
folder_images = os.path.join(data_folder, "imgs")
os.makedirs(folder_images, exist_ok=True)

# Load the dataset
ds = load_dataset("visual-layer/oxford-iiit-pet-vl-enriched", split="train[:50]")

# Build a DataFrame with image paths and captions. We persist images to disk because
# the retriever expects paths.
n_total = len(ds)
rows = []
for i in range(n_total):
    ex = ds[i]
    img = ex.get("image", None)
    caption = ex.get("caption_enriched", None)
    label = ex.get("label_cat_dog", None)
    img_path = os.path.join(folder_images, f"{i:06d}.jpg")
    try:
        img.save(img_path)
    except Exception:
        img.convert("RGB").save(img_path)
    rows.append({"img": img_path, "text": caption, "label": label})

df = pd.DataFrame(rows)
le = LabelEncoder()
df["class"] = le.fit_transform(df["label"])
df["class"].value_counts()
Seed set to 42

class
1    37
0    13
Name: count, dtype: int64
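
The DataFrame above only needs three columns (img, text, label), so the same table can be built from local files. The sketch below is one way to do it under an assumed, hypothetical folder layout (`<root>/<label>/<name>.jpg`, with an optional `<name>.txt` caption next to each image). `collect_rows` is an illustrative helper, not part of iMML.

```python
from pathlib import Path


def collect_rows(root, image_suffixes=(".jpg", ".jpeg", ".png")):
    """Collect one row per image found under root/<label>/.

    Hypothetical layout: the caption, if any, is stored in a .txt file with
    the same stem as the image. A missing caption becomes None, which marks
    the text modality as absent for that sample.
    """
    rows = []
    label_dirs = sorted(p for p in Path(root).iterdir() if p.is_dir())
    for label_dir in label_dirs:
        for img_path in sorted(label_dir.iterdir()):
            if img_path.suffix.lower() not in image_suffixes:
                continue  # skip captions and any other non-image file
            caption_path = img_path.with_suffix(".txt")
            caption = None
            if caption_path.exists():
                caption = caption_path.read_text(encoding="utf-8").strip()
            rows.append({"img": str(img_path), "text": caption,
                         "label": label_dir.name})
    return rows
```

Passing the result to `pd.DataFrame(...)` should give a table with the same columns as `df`, ready for the label-encoding step.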

Split the data into a memory bank (40%), a training set (40%), and a test set (20%).

train_df, test_df = train_test_split(df, test_size=0.2, shuffle=True, stratify=df["class"])
train_df, bank_df = train_test_split(train_df, test_size=0.5, shuffle=True, stratify=train_df["class"])
print("train_df", train_df.shape)
print("test_df", test_df.shape)
print("bank_df", bank_df.shape)
train_df.head()
train_df (20, 4)
test_df (10, 4)
bank_df (20, 4)
    img                              text                                               label  class
27  oxford_iiit_pet/imgs/000027.jpg  a large white dog standing on a patio near a b...  dog    1
24  oxford_iiit_pet/imgs/000024.jpg  a white cat laying on a floor                      cat    0
21  oxford_iiit_pet/imgs/000021.jpg  a black and tan dog with a blue collar             dog    1
44  oxford_iiit_pet/imgs/000044.jpg  a cat yawning on the floor                         cat    0
19  oxford_iiit_pet/imgs/000019.jpg  a dog laying on a red pillow                       dog    1
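
The 40/40/20 proportions follow from chaining two splits: 20% is held out first, and the remaining 80% is halved. The short sketch below reproduces the arithmetic. It assumes scikit-learn's rounding for a float `test_size` (the held-out part is rounded up and the rest goes to the other part); `split_sizes` is a hypothetical helper written for this illustration.

```python
import math


def split_sizes(n_samples, test_frac=0.2, bank_frac=0.5):
    """Return (n_train, n_bank, n_test) for the two-stage split."""
    # First split: hold out the test set.
    n_test = math.ceil(test_frac * n_samples)
    n_rest = n_samples - n_test
    # Second split: carve the memory bank out of what is left.
    n_bank = math.ceil(bank_frac * n_rest)
    n_train = n_rest - n_bank
    return n_train, n_bank, n_test


print(split_sizes(50))  # (20, 20, 10)
```

These sizes match the shapes printed above for the 50-sample subset.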


Step 3: Simulate missing modalities

To reflect realistic scenarios, we randomly introduce missing data using Amputer. Here, 60% of the training and test samples will have either the text or the image missing. Change p to increase or decrease the amount of incompleteness.

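Xs_train = [train_df[["img"]], train_df[["text"]]]
Xs_test = [test_df[["img"]], test_df[["text"]]]
# Keep the encoded labels; y_test is used again in the evaluation step.
y_train = train_df["class"]
y_test = test_df["class"]
# p is the proportion of samples that lose one modality.
amputer = Amputer(p=0.6, random_state=random_state)
Xs_train = amputer.fit_transform(Xs_train)
Xs_test = amputer.fit_transform(Xs_test)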
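
To make the missingness pattern concrete, here is a minimal pure-Python sketch of the idea: a fraction p of the samples loses exactly one modality, chosen at random, so every sample keeps at least one. This is only an illustration under those assumptions, not Amputer's implementation; `drop_random_modality` is a hypothetical name.

```python
import random


def drop_random_modality(samples, p, seed=0):
    """Return a copy of `samples` where a fraction `p` has one modality set to None.

    `samples` is a list of dicts with the keys "img" and "text".
    """
    rng = random.Random(seed)
    n_incomplete = round(p * len(samples))
    incomplete_idx = set(rng.sample(range(len(samples)), n_incomplete))
    out = []
    for i, sample in enumerate(samples):
        sample = dict(sample)  # copy so the input is left untouched
        if i in incomplete_idx:
            sample[rng.choice(["img", "text"])] = None
        out.append(sample)
    return out


toy = [{"img": f"{i:06d}.jpg", "text": f"caption {i}"} for i in range(10)]
amputed = drop_random_modality(toy, p=0.6, seed=42)
n_missing = sum(s["img"] is None or s["text"] is None for s in amputed)
print(n_missing)  # 6
```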

Step 4: Generate the prompts using a retriever

RAGPT needs prompts, which are created from a memory bank with a retriever. We use MCR (Multi-Channel Retriever) to construct a memory bank and generate prompts.

modalities = ["image", "text"]
batch_size = 64
estimator = MCR(batch_size=batch_size, modalities=modalities, save_memory_bank=True,
                prompt_path=data_folder, n_neighbors=2, generate_cap=True)

Xs_bank = [bank_df[["img"]], bank_df[["text"]]]
y_bank = bank_df["class"]

estimator.fit(Xs=Xs_bank, y=y_bank)
memory_bank = estimator.memory_bank_
print("memory_bank", memory_bank.shape)
memory_bank.head()
memory_bank (20, 8)
    item_id  img_path                         text                            q_i                                                q_t                                                label  prompt_image_path                 prompt_text_path
18  18       oxford_iiit_pet/imgs/000018.jpg  a gray cat laying on the floor  [-0.2883501350879669, 0.6082456111907959, 0.25...  [-0.520526111125946, -0.2758327126502991, 0.22...  0      oxford_iiit_pet/image/000018.npy  oxford_iiit_pet/text/000018.npy
7   7        oxford_iiit_pet/imgs/000007.jpg  a man holding a black dog       [-0.36506450176239014, 0.2776166796684265, -0....  [-0.25834551453590393, 0.5495427846908569, 0.3...  1      oxford_iiit_pet/image/000007.npy  oxford_iiit_pet/text/000007.npy
20  20       oxford_iiit_pet/imgs/000020.jpg  a cat is sitting on a branch    [-0.32218578457832336, -0.1820022314786911, 0....  [-0.8176320791244507, 0.08956064283847809, 0.7...  0      oxford_iiit_pet/image/000020.npy  oxford_iiit_pet/text/000020.npy
0   0        oxford_iiit_pet/imgs/000000.jpg  a cat walking on grass          [0.0411289781332016, 0.28625422716140747, 0.22...  [0.36405670642852783, 0.4739795923233032, 0.63...  0      oxford_iiit_pet/image/000000.npy  oxford_iiit_pet/text/000000.npy
46  46       oxford_iiit_pet/imgs/000046.jpg  a dog laying in the grass       [0.13979408144950867, 0.3674620985984802, -0.4...  [0.5348188877105713, 0.22137272357940674, 0.30...  1      oxford_iiit_pet/image/000046.npy  oxford_iiit_pet/text/000046.npy
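
The memory bank stores one embedding vector per item and modality (the q_i and q_t columns), and the retriever was configured with n_neighbors=2. As a rough intuition for what retrieval from such a bank involves, the sketch below ranks bank items by cosine similarity to a query vector and returns the two closest ids. The similarity measure, the toy vectors, and the helper names are assumptions made for illustration; this is not MCR's actual implementation.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def top_k_neighbors(query, bank, k=2):
    """Return the ids of the k bank items most similar to the query."""
    ranked = sorted(bank.items(), key=lambda item: cosine(query, item[1]),
                    reverse=True)
    return [item_id for item_id, _ in ranked[:k]]


# Toy 2-d embeddings keyed by item id (real embeddings are much longer).
bank = {18: [1.0, 0.0], 7: [0.0, 1.0], 20: [0.9, 0.1]}
neighbors = top_k_neighbors([1.0, 0.05], bank, k=2)
print(neighbors)  # [18, 20]
```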


Load generated training and testing prompts.

train_db (20, 14)
test_db (10, 14)

Step 5: Training the model

Create the loaders.

train_data = RAGPTDataset(database=train_db)
train_dataloader = DataLoader(dataset=train_data, batch_size=batch_size,
                              collate_fn=RAGPTCollator(), shuffle=True)

test_data = RAGPTDataset(database=test_db)
test_dataloader = DataLoader(dataset=test_data, batch_size=batch_size,
                             collate_fn=RAGPTCollator(), shuffle=False)

Train the RAGPT model using the generated prompts. For speed in this demo we train for only 2 epochs using the Lightning library.

trainer = Trainer(max_epochs=2, logger=False, enable_checkpointing=False)
estimator = RAGPT(cls_num=len(le.classes_))
trainer.fit(estimator, train_dataloader)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name  | Type        | Params | Mode
----------------------------------------------
0 | model | RAGPTModule | 118 M  | train
----------------------------------------------
7.3 M     Trainable params
111 M     Non-trainable params
118 M     Total params
473.226   Total estimated model params size (MB)
19        Modules in train mode
232       Modules in eval mode

`Trainer.fit` stopped: `max_epochs=2` reached.

Step 6: Advanced usage: track metrics during training

As with any other Lightning model, we can override the internal methods. For instance, we can record the validation loss after each step and average it at the end of every epoch.

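trainer = Trainer(max_epochs=2, logger=False, enable_checkpointing=False)
estimator = RAGPT(cls_num=len(le.classes_))
estimator.loss_list = []      # per-batch validation losses of the current epoch
estimator.agg_loss_list = []  # mean validation loss per epoch
validation_step = estimator.validation_step  # keep the original bound method

def compute_metric(*args):
    # Run the original validation step and record its loss.
    loss = validation_step(*args)
    estimator.loss_list.append(loss)
    return loss
estimator.validation_step = compute_metric

def agg_metric(*args):
    # Average the batch losses at the end of each validation epoch, then reset.
    estimator.agg_loss_list.append(torch.stack(estimator.loss_list).mean())
    estimator.loss_list = []
estimator.on_validation_epoch_end = agg_metric

# The test loader doubles as the validation loader in this demo.
trainer.fit(estimator, train_dataloader, test_dataloader)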
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name  | Type        | Params | Mode
----------------------------------------------
0 | model | RAGPTModule | 118 M  | train
----------------------------------------------
7.3 M     Trainable params
111 M     Non-trainable params
118 M     Total params
473.226   Total estimated model params size (MB)
19        Modules in train mode
232       Modules in eval mode

`Trainer.fit` stopped: `max_epochs=2` reached.
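
The tracking trick in this step is plain Python: keep a reference to the original bound method, then assign a wrapper to the instance that calls it and records the result. The sketch below shows the same pattern on a toy class so it can be run without Lightning. `ToyModel`, `recording_step`, and `close_epoch` are hypothetical names used only for illustration.

```python
class ToyModel:
    """Stand-in for a Lightning module; not part of iMML or Lightning."""

    def validation_step(self, batch, batch_idx):
        return sum(batch) / len(batch)  # pretend this is the batch loss


model = ToyModel()
model.step_losses = []   # losses of the current epoch
model.epoch_losses = []  # one averaged loss per epoch
original_step = model.validation_step  # bound method, still carries `self`


def recording_step(*args):
    # Delegate to the original method, then store what it returned.
    loss = original_step(*args)
    model.step_losses.append(loss)
    return loss


def close_epoch():
    # Average the stored losses and start a fresh list for the next epoch.
    model.epoch_losses.append(sum(model.step_losses) / len(model.step_losses))
    model.step_losses = []


# Assigning on the instance shadows the class method for this object only.
model.validation_step = recording_step

for epoch_batches in ([[1.0, 3.0], [2.0, 2.0]], [[0.5, 1.5]]):
    for batch_idx, batch in enumerate(epoch_batches):
        model.validation_step(batch, batch_idx)
    close_epoch()

print(model.epoch_losses)  # [2.0, 1.0]
```

Because the wrapper is attached to the instance rather than the class, other instances of the same model are unaffected.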

Step 7: Evaluation

After training, we can evaluate predictions and visualize the results.

preds = trainer.predict(estimator, test_dataloader)
preds = [batch.softmax(dim=1) for batch in preds]
preds = [pred for batch in preds for pred in batch]
preds = torch.stack(preds).argmax(1).cpu()
losses = [i.item() for i in estimator.agg_loss_list]

nrows, ncols = 2,3
test_df = pd.concat(Xs_test, axis=1)
test_df = pd.concat([test_df, y_test.to_frame("label")], axis=1)
test_df = test_df.reset_index(drop=True)
preds = preds[test_df.index]
fig, axes = plt.subplots(nrows, ncols, constrained_layout=True)
for i, (i_row, row) in enumerate(test_df.sample(n=nrows*ncols, random_state=random_state).iterrows()):
    pred = preds[i_row]
    image_to_show = row["img"]
    caption = row["text"]
    real_class = le.classes_[row["label"]]
    ax = axes[i//ncols, i%ncols]
    ax.axis("off")
    if isinstance(image_to_show, str):
        image_to_show = Image.open(image_to_show).resize((512, 512), Image.Resampling.LANCZOS)
        ax.imshow(image_to_show)
    else:
        ax.plot(0.5, 0.5, 'rx', markersize=100, markeredgewidth=10)
    pred_class = le.classes_[pred]
    c = "green" if pred_class == real_class else "red"
    ax.set_title(f"Pred:{pred_class}; Real:{real_class}", color=c)
    if isinstance(caption, str):
        caption = caption.split(" ")
        if len(caption) >=6:
            caption = caption[:len(caption)//2] + ["\n"] + caption[len(caption)//2:]
            caption = " ".join(caption)
        ax.annotate(caption, xy=(0.5, -0.08), xycoords='axes fraction', ha='center', va='top')
    else:
        ax.annotate("X", xy=(0.5, -0.08), xycoords='axes fraction', ha='center', va='top', color="red", fontsize=30)

shutil.rmtree(data_folder, ignore_errors=True)

[Figure: 2 × 3 grid of six test samples, each titled with its predicted and real class. Panel titles: Pred:cat; Real:cat, Pred:dog; Real:dog, Pred:dog; Real:dog, Pred:cat; Real:cat, Pred:dog; Real:dog, Pred:cat; Real:cat.]

ConfusionMatrixDisplay.from_predictions(y_true=y_test, y_pred=preds)
print("Testing metric:", matthews_corrcoef(y_true=y_test, y_pred=preds))
[Figure: confusion matrix of the test predictions.]
Testing metric: 0.8017837257372732
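
For a binary problem such as cat vs. dog, the Matthews correlation coefficient can be computed directly from the four confusion-matrix counts. The sketch below spells out the formula. The counts passed in are hypothetical values for a 10-sample test set, chosen for illustration and not read from the figure above; `mcc_from_counts` is likewise an illustrative helper.

```python
import math


def mcc_from_counts(tp, tn, fp, fn):
    """Matthews correlation coefficient for a binary confusion matrix."""
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    # By convention the score is 0 when any marginal count is zero.
    return (tp * tn - fp * fn) / denom if denom else 0.0


# Hypothetical counts: 7 dogs (one predicted as cat) and 3 cats (all correct).
example_mcc = mcc_from_counts(tp=6, tn=3, fp=0, fn=1)
print(round(example_mcc, 4))  # 0.8018
```

MCC ranges from -1 to 1 and, unlike accuracy, stays informative when the classes are imbalanced, as they are here.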

Despite using only 50 instances and two training epochs, the model reaches an MCC of about 0.80 on the 10 test samples, which we attribute to the pretrained models. With so few test samples, this estimate is noisy.

Summary of results

We first built a memory bank from 40% of the samples, kept separate from the training and test sets, and used the multi-channel retriever (MCR) from the iMML retrieve module to generate retrieval-augmented prompts. We then trained a model with the RAGPT algorithm available in iMML, with 60% of the training and test samples missing either the text or the image. The model remained accurate on the test set despite this incompleteness.

This example is intentionally simplified, using only 50 instances for demonstration. For stronger performance and more reliable results, the full dataset and longer training should be used.

Conclusion

This example illustrates how iMML supports accurate classification of vision–language datasets even when a substantial share of the samples is missing a modality.

Total running time of the script: (4 minutes 14.551 seconds)
