Retrieve

Multi-Channel Retriever (MCR)

class imml.retrieve.MCR(batch_size: int = 64, n_neighbors: int = 20, device: str = 'cpu', modalities: list = None, pretrained_model=None, processor=None, generate_cap: bool = False, prompt_path: str = None, pretrained_vilt=None, tokenizer=None, image_processor=None, max_text_len: int = 40, max_image_len: int = 145, save_memory_bank: bool = True)

Bases: object

Multi-Channel Retriever (MCR). [1] [2]

MCR is a multimodal retrieval framework that performs similarity-based matching within each modality, even when some modalities are missing. It builds a memory bank of multimodal embeddings and supports retrieval-augmented prompt generation for downstream tasks such as classification or generation.
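Within-modality matching can be illustrated with cosine similarity between embedding vectors. The following is a minimal NumPy sketch under that assumption, not MCR's actual implementation; the helper name top_k_cosine is hypothetical, and in MCR the embeddings would come from the pretrained encoder (a CLIP model by default).

```python
import numpy as np


def top_k_cosine(queries: np.ndarray, bank: np.ndarray, k: int):
    """Hypothetical helper: indices and cosine similarities of the k nearest bank rows."""
    # L2-normalize rows so that a dot product equals cosine similarity
    q = queries / np.linalg.norm(queries, axis=1, keepdims=True)
    b = bank / np.linalg.norm(bank, axis=1, keepdims=True)
    sims = q @ b.T  # shape (n_queries, n_bank)
    # Sort each row in descending order of similarity and keep the first k
    idx = np.argsort(-sims, axis=1)[:, :k]
    top_sims = np.take_along_axis(sims, idx, axis=1)
    return idx, top_sims


# Toy 2-D embeddings for a single modality (e.g. image embeddings)
bank = np.array([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])
queries = np.array([[0.9, 0.1]])
idx, sims = top_k_cosine(queries, bank, k=2)
print(idx.tolist())  # [[0, 2]]
```

Image queries are compared only against image embeddings and text queries only against text embeddings, which is what "within modalities" means here; k plays the role of n_neighbors.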

Parameters:
  • batch_size (int, default=64) -- Batch size used for encoding inputs during memory bank creation and inference.

  • n_neighbors (int, default=20) -- Number of neighbors to retrieve per sample during prediction.

  • device (str, default="cpu") -- Device to use for model inference, typically "cpu" or "cuda".

  • modalities (list of str, default=None) -- Names of the modalities. Options are "text" and "image".

  • pretrained_model (transformers.PreTrainedModel, default=None) -- A pretrained HuggingFace model used for encoding multimodal inputs (e.g., a CLIP model). If None, defaults to "openai/clip-vit-large-patch14-336".

  • processor (transformers.ProcessorMixin, default=None) -- HuggingFace processor corresponding to the pretrained model, used to preprocess image and text inputs. If None, defaults to the processor for "openai/clip-vit-large-patch14-336".

  • generate_cap (bool, default=False) -- Whether to generate retrieval-based prompts.

  • prompt_path (str, default=None) -- Path to save or load the generated prompts when generate_cap is True.

  • pretrained_vilt (transformers.PreTrainedModel, default=None) -- Pretrained model used for vision-language prompt generation. If None, defaults to ViltModel.from_pretrained('dandelin/vilt-b32-mlm').

  • tokenizer (transformers.BertTokenizer, default=None) -- Tokenizer used for text processing. If None, defaults to BertTokenizer.from_pretrained('dandelin/vilt-b32-mlm', do_lower_case=True).

  • image_processor (transformers.ViltImageProcessor, default=None) -- Image processor used with the ViLT model for image preprocessing. If None, defaults to ViltImageProcessor.from_pretrained('dandelin/vilt-b32-mlm').

  • max_text_len (int, default=40) -- Maximum token length for text inputs (used during prompt generation). Must not exceed the max_position_embeddings of the ViLT model (default: 40 for 'dandelin/vilt-b32-mlm').

  • max_image_len (int, default=145) -- Maximum token length for image inputs (used during prompt generation).

  • save_memory_bank (bool, default=True) -- Whether to store the memory bank of embeddings as the memory_bank_ attribute after fitting. If False, fit returns the memory bank instead of storing it.

memory_bank_

DataFrame storing the encoded modality representations used for retrieval. Only available if save_memory_bank is True. The columns are:

  • item_id: Unique identifier for each sample.
  • img_path: Path to the image file.
  • text: Textual content of the sample.
  • q_i: Image embedding.
  • q_t: Text embedding.
  • label: Label of the sample.
  • prompt_image_path: Path to the generated image prompt. Only if generate_cap is True.
  • prompt_text_path: Path to the generated text prompt. Only if generate_cap is True.

Type:

pd.DataFrame (n_samples, 6)
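To make this layout concrete, the sketch below builds a DataFrame with the base memory_bank_ columns, encoding inputs in chunks of batch_size. It assumes embeddings are stored as one array per row. The helpers toy_encode and build_memory_bank are hypothetical, and toy_encode is a deterministic stand-in for the real CLIP encoder that produces no meaningful embeddings.

```python
import numpy as np
import pandas as pd


def toy_encode(items, dim: int = 4) -> np.ndarray:
    """Hypothetical stand-in for a CLIP encoder: deterministic pseudo-embeddings."""
    rows = []
    for item in items:
        # Seed from the characters so identical inputs map to identical vectors
        rng = np.random.default_rng(sum(map(ord, str(item))))
        rows.append(rng.standard_normal(dim))
    return np.vstack(rows)


def build_memory_bank(img_paths, texts, labels, batch_size: int = 64) -> pd.DataFrame:
    """Hypothetical helper that assembles the base memory_bank_ columns."""
    q_i, q_t = [], []
    # Encode in chunks of batch_size, mirroring the batch_size parameter
    for start in range(0, len(texts), batch_size):
        stop = start + batch_size
        q_i.extend(toy_encode(img_paths[start:stop]))  # extend adds one row array per sample
        q_t.extend(toy_encode(texts[start:stop]))
    bank = pd.DataFrame({"item_id": list(range(len(texts))), "img_path": img_paths, "text": texts})
    bank["q_i"] = q_i  # one embedding array per row
    bank["q_t"] = q_t
    bank["label"] = labels
    return bank


bank = build_memory_bank(
    ["docs/figures/graph.png", "docs/figures/logo_imml.png"],
    ["This is the graphical abstract of iMML.", "This is the logo of iMML."],
    [0, 1],
    batch_size=1,
)
print(list(bank.columns))  # ['item_id', 'img_path', 'text', 'q_i', 'q_t', 'label']
image_matrix = np.stack(bank["q_i"].tolist())  # stack back into (n_samples, dim) for retrieval
```

Storing one array per cell keeps one row per sample. Stacking a column back into a matrix gives the form that a similarity search operates on.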

References

See also

RAGPT, RAGPTDataset

Example

>>> import pandas as pd
>>> from imml.retrieve import MCR
>>> Xs = [
...     pd.DataFrame(["docs/figures/graph.png", "docs/figures/logo_imml.png",
...                   "docs/figures/graph.png", "docs/figures/logo_imml.png"]),
...     pd.DataFrame(["This is the graphical abstract of iMML.", "This is the logo of iMML.",
...                   "This is the graphical abstract of iMML.", "This is the logo of iMML."]),
... ]
>>> y = [0, 1, 0, 1]
>>> modalities = ["image", "text"]
>>> estimator = MCR(modalities=modalities, n_neighbors=1)
>>> estimator.fit(Xs=Xs, y=y)
>>> memory_bank = estimator.memory_bank_
>>> estimator.predict(Xs=Xs)

fit(Xs: list, y)

Fit the retriever: encode the input data and build the memory bank.

Parameters:
  • Xs (list of array-like objects) --

    • Xs length: 2

    • Xs[i] shape: (n_samples_i, 1)

    A list with the images (as file paths) and the texts.

  • y (array-like of shape (n_samples,)) -- Target vector relative to Xs.

Returns:

self -- Fitted retriever, or the memory bank if save_memory_bank=False.

Return type:

MCR or pd.DataFrame

predict(Xs: list, memory_bank: DataFrame = None, n_neighbors: int = None)

Retrieve the most similar instances from the memory bank.

Parameters:
  • Xs (list of array-like objects) --

    • Xs length: 2

    • Xs[i] shape: (n_samples_i, 1)

    A list with the images (as file paths) and the texts.

  • memory_bank (pd.DataFrame (n_samples, 10)) -- Memory bank generated during fit. If None, the memory bank stored in the estimator is used.

  • n_neighbors (int, default=None) -- Number of neighbors to retrieve per sample during prediction. If None, the value stored in the estimator is used.

Returns:

pred -- Dictionary with the ids, similarities and labels of the retrieved items for each modality.

Return type:

dict
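The fallback rules for memory_bank and n_neighbors, together with the save_memory_bank switch on fit, can be mirrored by a small toy class. This is a hypothetical, simplified stand-in (ToyRetriever, using a dict memory bank and precomputed embeddings instead of raw images and texts), not the imml implementation. The result keys "ids", "sims" and "labels" are assumptions, since the exact keys of pred are not documented here.

```python
import numpy as np


class ToyRetriever:
    """Hypothetical, simplified stand-in mirroring the documented MCR contract."""

    def __init__(self, n_neighbors: int = 20, save_memory_bank: bool = True):
        self.n_neighbors = n_neighbors
        self.save_memory_bank = save_memory_bank

    def fit(self, embeddings: dict, y):
        # embeddings maps a modality name to an (n_samples, dim) array;
        # the real MCR encodes raw images and texts instead
        bank = {m: np.asarray(e, dtype=float) for m, e in embeddings.items()}
        bank["label"] = np.asarray(y)
        if self.save_memory_bank:
            self.memory_bank_ = bank
            return self
        return bank  # the caller passes this back to predict

    def predict(self, embeddings: dict, memory_bank=None, n_neighbors=None):
        # Fall back to the stored memory bank and n_neighbors when not given
        bank = self.memory_bank_ if memory_bank is None else memory_bank
        k = self.n_neighbors if n_neighbors is None else n_neighbors
        pred = {}
        for modality, queries in embeddings.items():
            # Within-modality cosine similarity: image vs image, text vs text
            q = np.asarray(queries, dtype=float)
            q = q / np.linalg.norm(q, axis=1, keepdims=True)
            b = bank[modality] / np.linalg.norm(bank[modality], axis=1, keepdims=True)
            sims = q @ b.T
            idx = np.argsort(-sims, axis=1)[:, :k]
            pred[modality] = {
                "ids": idx,
                "sims": np.take_along_axis(sims, idx, axis=1),
                "labels": bank["label"][idx],
            }
        return pred


emb = {"image": [[1.0, 0.0], [0.0, 1.0]], "text": [[0.5, 0.5], [1.0, -1.0]]}
retriever = ToyRetriever(n_neighbors=1, save_memory_bank=False)
bank = retriever.fit(emb, y=[0, 1])  # returns the bank, not self
pred = retriever.predict(emb, memory_bank=bank)
print(pred["image"]["labels"].tolist())  # [[0], [1]]
```

Because the queries here are the fitted items themselves, each item retrieves itself first. This section does not state whether MCR excludes such self-matches.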

fit_predict(Xs: list, y, n_neighbors: int = None)

Fit the retriever to the input data and retrieve the most similar instances.

Parameters:
  • Xs (list of array-like objects) --

    • Xs length: 2

    • Xs[i] shape: (n_samples_i, 1)

    A list with the images (as file paths) and the texts.

  • y (array-like of shape (n_samples,)) -- Target vector relative to Xs.

  • n_neighbors (int, default=None) -- Number of neighbors to retrieve per sample during prediction. If None, the value stored in the estimator is used.

Returns:

pred -- Dictionary with the ids, similarities and labels of the retrieved items for each modality.

Return type:

dict

transform(Xs: list, y, memory_bank: DataFrame = None, n_neighbors: int = None)

Generate retrieval-augmented prompts.

Parameters:
  • Xs (list of array-like objects) --

    • Xs length: 2

    • Xs[i] shape: (n_samples_i, 1)

    A list with the images (as file paths) and the texts.

  • y (array-like of shape (n_samples,)) -- Target vector relative to Xs.

  • memory_bank (pd.DataFrame (n_samples, 10)) -- Memory bank generated during fit. If None, the memory bank stored in the estimator is used.

  • n_neighbors (int, default=None) -- Number of neighbors to retrieve per sample during prediction. If None, the value stored in the estimator is used.

Returns:

database -- A database with the retrieval-augmented prompts. It contains the following columns:

  • item_id: Unique identifier for each sample.
  • img_path: Path to the image file.
  • text: Textual content of the sample.
  • label: Label of the sample.
  • observed_image: Indicator of whether the image was observed.
  • observed_text: Indicator of whether the text was observed.
  • i2i_id_list: List of ids of the retrieved items for image-to-image retrieval.
  • i2i_sims_list: List of similarities of the retrieved items for image-to-image retrieval.
  • i2i_label_list: List of labels of the retrieved items for image-to-image retrieval.
  • prompt_image_path: Path to the generated image prompt. Only if generate_cap is True.
  • t2t_id_list: List of ids of the retrieved items for text-to-text retrieval.
  • t2t_sims_list: List of similarities of the retrieved items for text-to-text retrieval.
  • t2t_label_list: List of labels of the retrieved items for text-to-text retrieval.
  • prompt_text_path: Path to the generated text prompt. Only if generate_cap is True.

Return type:

pd.DataFrame (n_samples, 14)
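The per-sample list columns of this database can be sketched as follows, assuming retrieval results are already available as (ids, similarities) arrays per channel and that a missing input is represented as None or NaN. The helper to_database is hypothetical, the inputs are placeholders, and the prompt_image_path and prompt_text_path columns are omitted because prompt generation is not modeled.

```python
import numpy as np
import pandas as pd


def to_database(img_paths, texts, labels, retrieved: dict, bank_labels) -> pd.DataFrame:
    """Hypothetical helper packing per-sample retrieval results into list columns."""
    db = pd.DataFrame({"item_id": list(range(len(texts))), "img_path": img_paths,
                       "text": texts, "label": labels})
    # Assumption: a modality is observed when its input is not None/NaN
    db["observed_image"] = db["img_path"].notna().astype(int)
    db["observed_text"] = db["text"].notna().astype(int)
    bank_labels = np.asarray(bank_labels)
    # retrieved maps a channel prefix ("i2i" or "t2t") to (ids, sims) arrays of shape (n_samples, k)
    for prefix, (ids, sims) in retrieved.items():
        db[f"{prefix}_id_list"] = [row.tolist() for row in ids]
        db[f"{prefix}_sims_list"] = [row.tolist() for row in sims]
        db[f"{prefix}_label_list"] = [bank_labels[row].tolist() for row in ids]
    return db


# Placeholder inputs: the second sample has no image
img_paths = ["docs/figures/graph.png", None]
texts = ["This is the graphical abstract of iMML.", "This is the logo of iMML."]
retrieved = {
    "i2i": (np.array([[0], [0]]), np.array([[1.0], [0.0]])),
    "t2t": (np.array([[0], [1]]), np.array([[1.0], [1.0]])),
}
db = to_database(img_paths, texts, [0, 1], retrieved, bank_labels=[0, 1])
print(db[["observed_image", "observed_text", "t2t_label_list"]])
```

Keeping the k retrieved items as a list inside one cell preserves one row per sample, consistent with the (n_samples, ...) shape of the documented output.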

fit_transform(Xs: list, y, n_neighbors: int = None)

Fit the retriever to the input data and generate retrieval-augmented prompts.

Parameters:
  • Xs (list of array-like objects) --

    • Xs length: 2

    • Xs[i] shape: (n_samples_i, 1)

    A list with the images (as file paths) and the texts.

  • y (array-like of shape (n_samples,)) -- Target vector relative to Xs.

  • n_neighbors (int, default=None) -- Number of neighbors to retrieve per sample during prediction. If None, the value stored in the estimator is used.

Returns:

database -- A database with the retrieval-augmented prompts. It contains the following columns:

  • item_id: Unique identifier for each sample.
  • img_path: Path to the image file.
  • text: Textual content of the sample.
  • label: Label of the sample.
  • observed_image: Indicator of whether the image was observed.
  • observed_text: Indicator of whether the text was observed.
  • i2i_id_list: List of ids of the retrieved items for image-to-image retrieval.
  • i2i_sims_list: List of similarities of the retrieved items for image-to-image retrieval.
  • i2i_label_list: List of labels of the retrieved items for image-to-image retrieval.
  • prompt_image_path: Path to the generated image prompt. Only if generate_cap is True.
  • t2t_id_list: List of ids of the retrieved items for text-to-text retrieval.
  • t2t_sims_list: List of similarities of the retrieved items for text-to-text retrieval.
  • t2t_label_list: List of labels of the retrieved items for text-to-text retrieval.
  • prompt_text_path: Path to the generated text prompt. Only if generate_cap is True.

Return type:

pd.DataFrame (n_samples, 14)