"""
==========================================================================
Retrieval on a vision–language dataset (flickr30k)
==========================================================================

This tutorial demonstrates how to retrieve samples from an incomplete vision–language dataset using `iMML`.
We will use the ``MCR`` retriever to find similar items across modalities (image/text) even when one modality
is missing. The example uses the public `nlphuji/flickr30k
<https://huggingface.co/datasets/nlphuji/flickr30k>`__ dataset from `Hugging Face Datasets
<https://huggingface.co/datasets>`__, so you don't need to prepare files manually.

What you will learn:

- How to load a vision–language dataset.
- How to build a memory bank with ``MCR`` for cross-modal retrieval.
- How to retrieve relevant items with missing modalities.
- How to visualize top retrieved examples for qualitative inspection.

This tutorial is fully reproducible. You can swap the loading section with your own data by constructing two
parallel lists: image paths and texts for each sample.
"""

# sphinx_gallery_thumbnail_number = 1

# License: BSD 3-Clause License

###################################
# Step 0: Prerequisites
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# To run this tutorial, install the extras for deep learning::
#
#     pip install imml[deep]
#
# Additionally, we will use the Hugging Face Datasets library to load Flickr30k::
#
#     pip install datasets


###################################
# Step 1: Import required libraries
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import lightning as L
import os
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from datasets import load_dataset
import shutil

from imml.ampute import Amputer
from imml.retrieve import MCR

################################
# Step 2: Prepare the dataset
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# We use the Flickr30k dataset, a public vision–language dataset with images and captions available on
# `Hugging Face Datasets <https://huggingface.co/datasets>`__ as `nlphuji/flickr30k
# <https://huggingface.co/datasets/nlphuji/flickr30k>`__. For retrieval, we will use the ``MCR`` method from the
# retrieve module.

random_state = 42
L.seed_everything(random_state)

# Local working directory (images will be saved here so MCR can read paths)
data_folder = "flickr30k"
folder_images = os.path.join(data_folder, "imgs")
os.makedirs(folder_images, exist_ok=True)

# Load a small subset of the dataset (the first 100 examples of the test split)
ds = load_dataset("nlphuji/flickr30k", split="test[:100]")

# Build a DataFrame with one row per sample: the path of the saved image and its first caption.
# We persist images to disk because the retriever expects paths.
n_total = len(ds)
rows = []
for i, ex in enumerate(ds):
    # Index the fields directly: a missing field then fails with a clear KeyError
    # instead of an opaque TypeError/AttributeError on ``None``.
    img = ex["image"]
    caption = ex["caption"][0]
    img_path = os.path.join(folder_images, f"{i:06d}.jpg")
    # JPEG cannot store alpha or palette modes, so normalise to RGB before saving.
    img.convert("RGB").save(img_path)
    rows.append({"img": img_path, "text": caption})

df = pd.DataFrame(rows)

# Split into 80% train and 20% test sets (the test set is whatever was not sampled for training)
train_df = df.sample(frac=0.8, random_state=random_state)
test_df = df.drop(index=train_df.index)
print("train_df", train_df.shape)
train_df.head()


###################################################
# Step 3: Simulate missing modalities
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# To reflect realistic scenarios, we randomly introduce missing data using ``Amputer``. In this case, 80% of test samples
# will have either text or image missing. You can change the parameter ``p`` to increase or decrease the amount of
# incompleteness.

# One single-column DataFrame per modality: image paths first, then captions.
Xs_test = [test_df[[column]] for column in ("img", "text")]
# The amputer uses its own fixed seed (0), separate from ``random_state`` above.
amputer = Amputer(p=0.8, random_state=0)
Xs_test = amputer.fit_transform(Xs_test)


########################################################
# Step 4: Generate the memory bank
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# We build the retriever with ``MCR``.
modalities = ["image", "text"]
estimator = MCR(batch_size=64, modalities=modalities, save_memory_bank=True)

# One single-column DataFrame per modality, in the same order as ``modalities``.
Xs_train = [train_df[[column]] for column in ("img", "text")]
# Flickr30k provides no labels, so pass an all-zero placeholder aligned with the training index.
y_train = pd.Series(np.zeros(train_df.shape[0]), index=train_df.index)

estimator.fit(Xs=Xs_train, y=y_train)
# Keep a handle on the memory bank so retrieved ids can be looked up when plotting.
memory_bank = estimator.memory_bank_


########################################################
# Step 5: Retrieve
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# We retrieve the two most similar items from the memory bank for each sample in the test set.

# Query the fitted retriever with the amputed test modalities, asking for 2 neighbours per sample.
preds = estimator.predict(Xs=Xs_test, n_neighbors=2)

################################
# This is the content of the prediction: first its top-level keys, then the keys of its
# ``"image"`` and ``"text"`` entries.
preds.keys()
################################
preds["image"].keys()
################################
preds["text"].keys()


########################################################
# Step 6: Visualize the retrieved instances
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# We can visualize the top-2 retrieved instances for a given target sample.
# Here we focus on qualitative inspection: looking at the images and reading the captions
# of the retrieved items to assess whether they are semantically similar to the target.
#
# The target instance is displayed in the leftmost column, followed by the most similar instances in descending
# order of similarity. Note that some instances have missing modalities, which are marked with a red cross in the
# plot. In this example, the first two instances are missing the image modality, while the last one is missing the
# text modality.

# Grid layout: one row per target sample; column 0 is the target, columns 1.. are its retrieved neighbours.
nrows, ncols = 3, 3
fig, axes = plt.subplots(nrows, ncols, constrained_layout=True)
for row in range(nrows):
    for col in range(ncols):
        ax = axes[row, col]
        ax.axis("off")
        if col == 0:
            # ``.iloc[row, 0]`` is purely positional. The previous ``.iloc[row][0]`` relied on the
            # deprecated positional fallback of ``Series.__getitem__`` for label-indexed Series.
            image_to_show = Xs_test[0].iloc[row, 0]
            caption = Xs_test[1].iloc[row, 0]
            ax.set_title("Target instance")
        else:
            rank = col - 1  # 0-based position among the retrieved neighbours
            try:
                retrieved_instance = preds["image"]["id"][row][rank]
            except IndexError:
                # No image-based result at this position (presumably because the target has no
                # image), so use the text-based retrieval instead.
                retrieved_instance = preds["text"]["id"][row][rank]
            retrieved_instance = memory_bank.loc[retrieved_instance]
            image_to_show = retrieved_instance["img_path"]
            caption = retrieved_instance["text"]
            ax.set_title(f"Top-{rank + 1}")

        if isinstance(image_to_show, str):
            # Close the file as soon as the resized copy exists, so the image folder can be
            # removed below without dangling file handles.
            with Image.open(image_to_show) as img_file:
                image_to_show = img_file.resize((512, 512), Image.Resampling.LANCZOS)
            ax.imshow(image_to_show)
        else:
            # Missing image modality: draw a red cross instead.
            ax.plot(0.5, 0.5, 'rx', markersize=100, markeredgewidth=10)
        if isinstance(caption, str):
            words = caption.split(" ")
            if len(words) >= 6:
                # Break long captions into four lines so they fit under the image.
                quarter = len(words) // 4
                chunks = [
                    words[:quarter],
                    words[quarter:2 * quarter],
                    words[2 * quarter:3 * quarter],
                    words[3 * quarter:],
                ]
                caption = "\n".join(" ".join(chunk) for chunk in chunks)
            # Short captions are annotated unchanged. Previously they were left as a list of
            # words, so the plot showed the list's repr instead of the sentence.
            ax.annotate(caption, xy=(0.5, -0.08), xycoords='axes fraction', ha='center', va='top')
        else:
            # Missing text modality: show a red "X" where the caption would be.
            ax.annotate("X", xy=(0.5, -0.08), xycoords='axes fraction', ha='center', va='top', color="red",
                        fontsize=30)

# Clean up the local working directory once the figure has been built.
shutil.rmtree(data_folder, ignore_errors=True)

###################################
# Summary of results
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# We used the ``MCR`` retriever from `iMML` to identify the most relevant instances from a
# memory bank, even when one of the modalities (image or text) was missing.
#
# This example is intentionally simplified, using only a few instances for demonstration.
# For stronger performance and more reliable results, the full dataset should be used.

###################################
# Conclusion
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# This example illustrates how `iMML` enables robust retrieval and similarity search in vision-language datasets,
# even in the presence of missing modalities.