Retrieval on a vision–language dataset (flickr30k)

This tutorial demonstrates how to retrieve samples from an incomplete vision–language dataset using iMML. We will use the MCR retriever to find similar items across modalities (image/text) even when one modality is missing. The example uses the public nlphuji/flickr30k dataset from Hugging Face Datasets, so you don't need to prepare files manually.

What you will learn:

  • How to load a vision–language dataset.

  • How to build a memory bank with MCR for cross-modal retrieval.

  • How to retrieve relevant items with missing modalities.

  • How to visualize top retrieved examples for qualitative inspection.

This tutorial is fully reproducible. You can swap the loading section with your own data by constructing two parallel lists: image paths and texts for each sample.
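
As a minimal sketch of that swap, the snippet below pairs two parallel lists into row dictionaries with the same "img" and "text" keys used in Step 2. The helper name build_rows and the example paths and captions are hypothetical placeholders, not part of iMML.

```python
def build_rows(image_paths, texts):
    """Pair parallel lists of image paths and captions into row dicts.

    Hypothetical helper: the keys "img" and "text" mirror the columns
    used in Step 2 of this tutorial.
    """
    if len(image_paths) != len(texts):
        raise ValueError("image_paths and texts must have the same length")
    return [{"img": path, "text": text} for path, text in zip(image_paths, texts)]


# Placeholder data; replace with your own files and captions.
image_paths = ["my_data/0001.jpg", "my_data/0002.jpg"]
texts = ["A dog runs on the beach.", "Two people ride bicycles."]

rows = build_rows(image_paths, texts)
print(rows[0])
```

The resulting list can be passed to pd.DataFrame in the same way the tutorial does with its own rows list.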


# License: BSD 3-Clause License

Step 0: Prerequisites

To run this tutorial, install the extras for deep learning:

pip install "imml[deep]"

Additionally, we will use the Hugging Face Datasets library to load Flickr30k:

pip install datasets

Step 1: Import required libraries

import lightning as L
import os
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from datasets import load_dataset
import shutil

from imml.ampute import Amputer
from imml.retrieve import MCR

Step 2: Prepare the dataset

We use the Flickr30k dataset, a public vision–language dataset with images and captions available on Hugging Face Datasets as nlphuji/flickr30k. For retrieval, we will use the MCR method from the retrieve module.

random_state = 42
L.seed_everything(random_state)

# Local working directory (images will be saved here so MCR can read paths)
data_folder = "flickr30k"
folder_images = os.path.join(data_folder, "imgs")
os.makedirs(folder_images, exist_ok=True)

# Load the first 100 samples of the test split
ds = load_dataset("nlphuji/flickr30k", split="test[:100]")

# Build a DataFrame with image paths and captions. We persist images to disk because
# the retriever expects paths.
n_total = len(ds)
rows = []
for i in range(n_total):
    ex = ds[i]
    img = ex["image"]
    # Each image has several captions; keep the first one
    caption = ex["caption"][0]
    img_path = os.path.join(folder_images, f"{i:06d}.jpg")
    img.save(img_path)
    rows.append({"img": img_path, "text": caption})

df = pd.DataFrame(rows)

# Split into 80% train and 20% test sets
train_df = df.sample(frac=0.8, random_state=random_state)
test_df = df.drop(index=train_df.index)
print("train_df", train_df.shape)
train_df.head()
train_df (80, 2)
                          img                                               text
83  flickr30k/imgs/000083.jpg  A man is drilling through the frozen ice of a ...
53  flickr30k/imgs/000053.jpg  People on two balconies and a man climbing up ...
70  flickr30k/imgs/000070.jpg  Large brown dog running away from the sprinkle...
45  flickr30k/imgs/000045.jpg  Bride and groom walking side by side out of fo...
44  flickr30k/imgs/000044.jpg  A man in black approaches a strange silver obj...


Step 3: Simulate missing modalities

To reflect realistic scenarios, we randomly introduce missing data using Amputer. Here, 80% of the test samples will be missing either the text or the image. Change the p parameter to make the data more or less incomplete.

Xs_test = [test_df[["img"]], test_df[["text"]]]
amputer = Amputer(p=0.8, random_state=0)
Xs_test = amputer.fit_transform(Xs_test)
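
To make the idea concrete, here is a small standalone sketch of this kind of masking on plain Python lists. It is only an illustration under simple assumptions (exactly one modality is dropped per affected sample, chosen with equal probability); it is not Amputer's implementation, and the function name mask_one_modality is hypothetical.

```python
import random


def mask_one_modality(images, texts, p, seed=0):
    """Return copies of both lists where a fraction p of samples lose
    exactly one modality (set to None), chosen at random.

    Illustrative only: this is not Amputer's implementation.
    """
    rng = random.Random(seed)
    n = len(images)
    images, texts = list(images), list(texts)
    # Pick which samples become incomplete
    incomplete = rng.sample(range(n), k=round(p * n))
    for i in incomplete:
        # Drop either the image or the text, never both
        if rng.random() < 0.5:
            images[i] = None
        else:
            texts[i] = None
    return images, texts


images = [f"img_{i}.jpg" for i in range(10)]
texts = [f"caption {i}" for i in range(10)]
masked_images, masked_texts = mask_one_modality(images, texts, p=0.8)

n_incomplete = sum(
    im is None or tx is None for im, tx in zip(masked_images, masked_texts)
)
print(n_incomplete)  # 8 of the 10 samples are missing one modality
```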

Step 4: Generate the memory bank

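We build the retriever with MCR and fit it on the training set. Fitting builds the memory bank, which we read from the memory_bank_ attribute and use later to look up the retrieved items.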

modalities = ["image", "text"]
estimator = MCR(batch_size=64, modalities=modalities, save_memory_bank=True)

Xs_train = [train_df[["img"]], train_df[["text"]]]
# Use dummy labels for API compatibility (labels are not provided in Flickr30k)
y_train = pd.Series(np.zeros(len(train_df)), index=train_df.index)

estimator.fit(Xs=Xs_train, y=y_train)
memory_bank = estimator.memory_bank_
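
For intuition, a memory bank can be thought of as a table of stored items that a query is compared against. The sketch below shows a generic nearest-neighbor lookup with cosine similarity over made-up 3-dimensional vectors. It is not MCR's implementation: the embedding model and similarity measure that MCR uses internally are not described in this tutorial and may differ, and the names cosine, top_k, and bank are hypothetical.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm


def top_k(query, bank, k):
    """Return the k (id, similarity) pairs most similar to the query."""
    scored = [(item_id, cosine(query, vec)) for item_id, vec in bank.items()]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Toy memory bank: id -> made-up 3-dimensional embedding
bank = {
    "a": [1.0, 0.0, 0.0],
    "b": [0.0, 1.0, 0.0],
    "c": [0.9, 0.1, 0.0],
}

neighbors = top_k([1.0, 0.0, 0.0], bank, k=2)
print([item_id for item_id, _ in neighbors])  # ['a', 'c']
```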

Step 5: Retrieve

We retrieve the two most similar items for each sample in the test set.

preds = estimator.predict(Xs=Xs_test, n_neighbors=2)

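The prediction is a dictionary with one entry per modality. Each entry holds the ids, similarity scores, and labels of the retrieved items.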

preds.keys()
dict_keys(['image', 'text'])
preds["image"].keys()
dict_keys(['id', 'sims', 'label'])
preds["text"].keys()
dict_keys(['id', 'sims', 'label'])
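
The following sketch shows one way to read a structure with these keys. The mock values and the nested layout (one list of neighbors per query) are assumptions made for illustration; the indexing pattern follows the preds[modality]["id"][row][col] access used in Step 6. The helper best_match is hypothetical.

```python
# Mock prediction with the same keys as above. The nested layout
# (one list of neighbors per query) is an assumption for illustration.
mock_preds = {
    "image": {
        "id": [[12, 7], [3, 40]],
        "sims": [[0.91, 0.85], [0.77, 0.70]],
        "label": [[0, 0], [0, 0]],
    },
    "text": {
        "id": [[5, 12], [3, 9]],
        "sims": [[0.88, 0.80], [0.95, 0.60]],
        "label": [[0, 0], [0, 0]],
    },
}


def best_match(preds, modality, query_idx):
    """Return (id, similarity) of the top-ranked neighbor for one query."""
    entry = preds[modality]
    return entry["id"][query_idx][0], entry["sims"][query_idx][0]


for modality in mock_preds:
    print(modality, best_match(mock_preds, modality, query_idx=0))
```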

Step 6: Visualize the retrieved instances

We can visualize the top-2 retrieved instances for a given target sample. Here we focus on qualitative inspection: looking at the images and reading the captions of the retrieved items to assess whether they are semantically similar to the target.

The target instance is displayed in the leftmost column, followed by the most similar instances in descending order of similarity. Note that some instances have missing modalities, which will not appear in the plot. In this example, the first two instances are missing the image modality, while the last one is missing the text modality.

nrows, ncols = 3, 3
fig, axes = plt.subplots(nrows, ncols, constrained_layout=True)
for i in range(nrows * ncols):
    row, col = i // ncols, i % ncols
    ax = axes[row, col]
    ax.axis("off")
    if col == 0:
        # First column: the target (query) instance
        image_to_show = Xs_test[0].iloc[row, 0]
        caption = Xs_test[1].iloc[row, 0]
        ax.set_title("Target instance")
    else:
        col -= 1
        try:
            retrieved_instance = preds["image"]["id"][row][col]
        except IndexError:
            retrieved_instance = preds["text"]["id"][row][col]
        retrieved_instance = memory_bank.loc[retrieved_instance]
        image_to_show = retrieved_instance["img_path"]
        caption = retrieved_instance["text"]
        ax.set_title(f"Top-{col+1}")

    if isinstance(image_to_show, str):
        image_to_show = Image.open(image_to_show).resize((512, 512), Image.Resampling.LANCZOS)
        ax.imshow(image_to_show)
    else:
        ax.plot(0.5, 0.5, 'rx', markersize=100, markeredgewidth=10)
    if isinstance(caption, str):
        words = caption.split(" ")
        if len(words) >= 6:
            # Break long captions into four lines of roughly equal length
            q = len(words) // 4
            words = (words[:q] + ["\n"] + words[q:2 * q] + ["\n"]
                     + words[2 * q:3 * q] + ["\n"] + words[3 * q:])
        # Join in all cases so short captions are also passed as a string
        caption = " ".join(words)
        ax.annotate(caption, xy=(0.5, -0.08), xycoords='axes fraction', ha='center', va='top')
    else:
        ax.annotate("X", xy=(0.5, -0.08), xycoords='axes fraction', ha='center', va='top', color="red", fontsize=30)

shutil.rmtree(data_folder, ignore_errors=True)

[Figure: 3×3 grid of panels. Each row shows a target instance followed by its Top-1 and Top-2 retrieved instances.]

Summary of results

We used the MCR retriever from iMML to identify the most relevant instances from a memory bank, even when one of the modalities (image or text) was missing.

This example is intentionally simplified, using only a few instances for demonstration. For stronger performance and more reliable results, the full dataset should be used.

Conclusion

This example illustrates how iMML enables robust retrieval and similarity search in vision-language datasets, even in the presence of missing modalities.

Total running time of the script: (3 minutes 58.062 seconds)
