Retrieve
Multi-Channel Retriever (MCR)
- class imml.retrieve.MCR(batch_size: int = 64, n_neighbors: int = 20, device: str = 'cpu', modalities: list = None, pretrained_model=None, processor=None, generate_cap: bool = False, prompt_path: str = None, pretrained_vilt=None, tokenizer=None, image_processor=None, max_text_len: int = 128, max_image_len: int = 145, save_memory_bank: bool = True)
Bases: Module
Multi-Channel Retriever (MCR) [1] [2].
MCR is a multimodal retrieval framework that performs similarity-based matching within each modality (image-to-image and text-to-text), even when some samples are missing a modality. It builds a memory bank of multimodal embeddings and supports retrieval-augmented prompt generation for downstream tasks such as classification or generation.
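Retrieval within a modality can be pictured as a nearest-neighbor search over the memory bank: a query image embedding is compared only against stored image embeddings (and text against text), and the most similar entries are returned. The sketch below assumes cosine similarity as the score and uses random NumPy vectors in place of CLIP embeddings; `top_k_cosine` is a hypothetical helper, not the imml implementation.

```python
import numpy as np


def top_k_cosine(queries: np.ndarray, bank: np.ndarray, k: int):
    """Return indices and cosine similarities of the k most similar bank rows.

    Hypothetical helper for illustration only; it is not part of imml.
    """
    # L2-normalize rows so that a dot product equals cosine similarity
    q = queries / np.linalg.norm(queries, axis=1, keepdims=True)
    b = bank / np.linalg.norm(bank, axis=1, keepdims=True)
    sims = q @ b.T                           # shape (n_queries, n_bank)
    k = min(k, bank.shape[0])                # never ask for more rows than exist
    idx = np.argsort(-sims, axis=1)[:, :k]   # most similar first
    return idx, np.take_along_axis(sims, idx, axis=1)


# Random vectors stand in for image embeddings from a CLIP-like encoder.
rng = np.random.default_rng(0)
image_bank = rng.normal(size=(5, 8))                      # 5 stored image embeddings
query = image_bank[[2]] + 0.01 * rng.normal(size=(1, 8))  # near-copy of item 2
idx, sims = top_k_cosine(query, image_bank, k=3)
```

Because the query is a slightly perturbed copy of stored item 2, that item is ranked first.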
- Parameters:
batch_size (int, default=64) -- Batch size used for encoding inputs during memory bank creation and inference.
n_neighbors (int, default=20) -- Number of neighbors to retrieve per sample during prediction.
device (str, default="cpu") -- Device to use for model inference, typically "cpu" or "cuda".
modalities (list of str, default=None) -- Names of the modalities. Options are "text" and "image".
pretrained_model (transformers.PreTrainedModel, default=None) -- A pretrained HuggingFace model used for encoding multimodal inputs (e.g., CLIP model). If None, defaults to "openai/clip-vit-large-patch14-336".
processor (transformers.ProcessorMixin, default=None) -- HuggingFace processor corresponding to the pretrained model. Used to preprocess image/text inputs. If None, defaults to processor for "openai/clip-vit-large-patch14-336".
generate_cap (bool, default=False) -- Whether to generate retrieval-based prompts.
prompt_path (str, default=None) -- Path to save or load the generated prompts when generate_cap is True.
pretrained_vilt (transformers.PreTrainedModel, default=None) -- Pretrained model used for vision-language prompt generation. If None, defaults to ViltModel.from_pretrained('dandelin/vilt-b32-mlm').
tokenizer (transformers.BertTokenizer, default=None) -- Tokenizer used for text processing. If None, defaults to BertTokenizer.from_pretrained('dandelin/vilt-b32-mlm', do_lower_case=True).
image_processor (transformers.ViltImageProcessor, default=None) -- Image processor used with the ViLT model for image preprocessing. If None, defaults to ViltImageProcessor.from_pretrained('dandelin/vilt-b32-mlm').
max_text_len (int, default=128) -- Maximum token length for text inputs (used during prompt generation).
max_image_len (int, default=145) -- Maximum token length for image inputs (used during prompt generation).
save_memory_bank (bool, default=True) -- Whether to store the memory bank of embeddings as the memory_bank_ attribute after fitting. If False, fit returns the memory bank instead.
- memory_bank_
DataFrame storing the encoded modality representations used for retrieval. Only available if save_memory_bank is True. The columns are:
  - item_id: Unique identifier for each sample.
  - img_path: Path to the image file.
  - text: Textual content of the sample.
  - q_i: Image embedding.
  - q_t: Text embedding.
  - label: Label of the sample.
  - prompt_image_path: Path to the generated image prompt. Only if generate_cap is True.
  - prompt_text_path: Path to the generated text prompt. Only if generate_cap is True.
- Type:
pd.DataFrame (n_samples, 6), or (n_samples, 8) if generate_cap is True
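To illustrate how a memory bank can tolerate missing modalities, the toy example below keeps one record per sample with the documented item_id, q_i, q_t and label fields, and assumes a missing embedding is stored as `None` (how imml actually marks a missing modality is not described here). Each retrieval channel then searches only the rows where its modality is observed. The list of dicts and the `channel` helper are hypothetical stand-ins for the real pd.DataFrame.

```python
import numpy as np

# Hypothetical toy memory bank mirroring the documented columns
# (item_id, q_i, q_t, label); a missing modality is stored as None here.
memory_bank = [
    {"item_id": 0, "q_i": np.array([1.0, 0.0]), "q_t": np.array([0.0, 1.0]), "label": 0},
    {"item_id": 1, "q_i": None,                 "q_t": np.array([1.0, 0.0]), "label": 1},
    {"item_id": 2, "q_i": np.array([0.0, 1.0]), "q_t": None,                 "label": 0},
]


def channel(bank, column):
    """Collect ids and stacked embeddings of the rows where `column` is observed."""
    rows = [r for r in bank if r[column] is not None]
    ids = [r["item_id"] for r in rows]
    return ids, np.stack([r[column] for r in rows])


image_ids, image_embs = channel(memory_bank, "q_i")  # image channel: rows 0 and 2
text_ids, text_embs = channel(memory_bank, "q_t")    # text channel: rows 0 and 1
```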
References
Example
>>> from imml.retrieve import MCR
>>> images = ["docs/figures/graph.png", "docs/figures/logo_imml.png",
...           "docs/figures/graph.png", "docs/figures/logo_imml.png"]
>>> texts = ["This is the graphical abstract of iMML.", "This is the logo of iMML.",
...          "This is the graphical abstract of iMML.", "This is the logo of iMML."]
>>> Xs = [images, texts]
>>> y = [0, 1, 0, 1]
>>> modalities = ["image", "text"]
>>> estimator = MCR(modalities=modalities)
>>> estimator.fit(Xs=Xs, y=y)
>>> memory_bank = estimator.memory_bank_
>>> preds = estimator.predict(Xs=Xs)
- fit(Xs: list, y)
Fit the estimator to the input data by encoding it into a memory bank.
- Parameters:
Xs (list of array-like objects) --
Xs length: 2
Xs[i] shape: (n_samples_i, 1)
A list with images and texts.
y (array-like of shape (n_samples,)) -- Target vector relative to Xs.
- Returns:
self or memory_bank -- The fitted estimator, or the memory bank if save_memory_bank is False.
- Return type:
MCR or pd.DataFrame
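The return value of fit depends on save_memory_bank. The hypothetical `ToyRetriever` below only mimics that documented contract; it performs no encoding and is not part of imml.

```python
class ToyRetriever:
    """Hypothetical stand-in mimicking the documented return contract of MCR.fit."""

    def __init__(self, save_memory_bank: bool = True):
        self.save_memory_bank = save_memory_bank

    def fit(self, Xs, y):
        # Placeholder for encoding: one record per sample with its label
        bank = [{"item_id": i, "label": label} for i, label in enumerate(y)]
        if self.save_memory_bank:
            self.memory_bank_ = bank  # stored as a fitted attribute
            return self               # allows chaining, e.g. fit(...).predict(...)
        return bank                   # caller keeps the bank and passes it on later


stored = ToyRetriever(save_memory_bank=True).fit(Xs=[[], []], y=[0, 1])
returned = ToyRetriever(save_memory_bank=False).fit(Xs=[[], []], y=[0, 1])
```

When save_memory_bank is False, keep the returned bank and pass it back through the memory_bank argument of predict or transform.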
- predict(Xs: list = None, memory_bank: DataFrame = None, n_neighbors: int = None)
Retrieve the most similar instances.
- Parameters:
Xs (list of array-like objects) --
Xs length: 2
Xs[i] shape: (n_samples_i, 1)
A list with images and texts.
memory_bank (pd.DataFrame, default=None) -- Memory bank generated during fit, with the columns described in memory_bank_. If None, the memory bank stored in the estimator is used.
n_neighbors (int, default=None) -- Number of neighbors to retrieve per sample during prediction. If None, the value stored in the estimator is used.
- Returns:
pred -- Dictionary with the ids, similarities and labels of the retrieved items for each modality.
- Return type:
dict
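The exact keys of the dictionary returned by predict are not spelled out above. As a sketch, assuming a per-modality entry that holds neighbor ids, similarities and labels, a precomputed similarity matrix could be turned into that shape as follows, using the same fallback for n_neighbors (the estimator's value when None). `format_neighbors`, the key names and `default_n_neighbors` are hypothetical.

```python
import numpy as np


def format_neighbors(sims, bank_ids, bank_labels, n_neighbors=None, default_n_neighbors=20):
    """Turn a (n_queries, n_bank) similarity matrix into ids/sims/labels lists.

    Hypothetical helper; the output layout is an assumption, not imml's API.
    """
    # Fall back to the estimator-level value when no override is given
    k = default_n_neighbors if n_neighbors is None else n_neighbors
    order = np.argsort(-sims, axis=1)[:, :k]  # most similar first
    return {
        "ids": [[bank_ids[j] for j in row] for row in order],
        "sims": np.take_along_axis(sims, order, axis=1).tolist(),
        "labels": [[bank_labels[j] for j in row] for row in order],
    }


sims = np.array([[0.9, 0.1, 0.5],
                 [0.2, 0.8, 0.3]])
pred = {"image": format_neighbors(sims, ["a", "b", "c"], [0, 1, 0], n_neighbors=2)}
```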
- fit_predict(Xs: list, y, n_neighbors: int = None)
Fit the estimator to the input data and retrieve the most similar instances.
- Parameters:
Xs (list of array-like objects) --
Xs length: 2
Xs[i] shape: (n_samples_i, 1)
A list with images and texts.
y (array-like of shape (n_samples,)) -- Target vector relative to Xs.
n_neighbors (int, default=None) -- Number of neighbors to retrieve per sample during prediction. If None, the value stored in the estimator is used.
- Returns:
pred -- Dictionary with the ids, similarities and labels of the retrieved items for each modality.
- Return type:
dict
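fit_predict queries the memory bank with the same samples that were used to build it. In a plain nearest-neighbor search this means every sample's top match is itself, with similarity 1.0. This reference does not say whether MCR removes such self-matches; the sketch below only shows the effect and one common way to mask it (setting the diagonal of the similarity matrix to -inf). The toy embeddings are made up for illustration.

```python
import numpy as np

# The same toy embeddings act as both the memory bank and the queries.
embs = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]])
unit = embs / np.linalg.norm(embs, axis=1, keepdims=True)
sims = unit @ unit.T                    # cosine similarity; diagonal is 1.0
naive_top1 = sims.argmax(axis=1)        # every sample retrieves itself

masked = sims.copy()
np.fill_diagonal(masked, -np.inf)       # assumption: exclude each sample's own entry
top1_excluding_self = masked.argmax(axis=1)
```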
- transform(Xs: list, y, memory_bank: DataFrame = None, n_neighbors: int = None)
Generate retrieval-augmented prompts.
- Parameters:
Xs (list of array-like objects) --
Xs length: 2
Xs[i] shape: (n_samples_i, 1)
A list with images and texts.
y (array-like of shape (n_samples,)) -- Target vector relative to Xs.
memory_bank (pd.DataFrame, default=None) -- Memory bank generated during fit, with the columns described in memory_bank_. If None, the memory bank stored in the estimator is used.
n_neighbors (int, default=None) -- Number of neighbors to retrieve per sample during prediction. If None, the value stored in the estimator is used.
- Returns:
database -- A DataFrame with the retrieval-augmented prompts. It contains the following columns:
  - item_id: Unique identifier for each sample.
  - img_path: Path to the image file.
  - text: Textual content of the sample.
  - label: Label of the sample.
  - observed_image: Indicator of whether the image was observed.
  - observed_text: Indicator of whether the text was observed.
  - i2i_id_list: List of ids of the items retrieved through the image-to-image channel.
  - i2i_sims_list: List of similarities of the items retrieved through the image-to-image channel.
  - i2i_label_list: List of labels of the items retrieved through the image-to-image channel.
  - prompt_image_path: Path to the generated image prompt. Only if generate_cap is True.
  - t2t_id_list: List of ids of the items retrieved through the text-to-text channel.
  - t2t_sims_list: List of similarities of the items retrieved through the text-to-text channel.
  - t2t_label_list: List of labels of the items retrieved through the text-to-text channel.
  - prompt_text_path: Path to the generated text prompt. Only if generate_cap is True.
- Return type:
pd.DataFrame (n_samples, 14) if generate_cap is True, otherwise (n_samples, 12)
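Each row of the database returned by transform combines the sample's own fields with the neighbor lists of both channels. The sketch below assembles one such row with the documented column names, leaving out the prompt paths, which only appear when generate_cap is True. `build_row` is hypothetical, and filling the lists of an unobserved modality with empty lists is an assumption, not documented imml behavior.

```python
def build_row(item_id, img_path, text, label, i2i, t2t):
    """Assemble one output row from per-channel (ids, sims, labels) tuples.

    Hypothetical sketch: column names follow the documented schema, but using
    empty lists for an unobserved modality is an assumption.
    """
    i2i_ids, i2i_sims, i2i_labels = i2i if img_path is not None else ([], [], [])
    t2t_ids, t2t_sims, t2t_labels = t2t if text is not None else ([], [], [])
    return {
        "item_id": item_id, "img_path": img_path, "text": text, "label": label,
        "observed_image": img_path is not None,
        "observed_text": text is not None,
        "i2i_id_list": i2i_ids, "i2i_sims_list": i2i_sims, "i2i_label_list": i2i_labels,
        "t2t_id_list": t2t_ids, "t2t_sims_list": t2t_sims, "t2t_label_list": t2t_labels,
    }


# A sample whose image is missing: only the text channel contributes neighbors.
row = build_row(3, None, "This is the logo of iMML.", 1,
                i2i=([0], [0.9], [0]), t2t=([1, 3], [0.95, 0.7], [1, 1]))
```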
- fit_transform(Xs: list, y, n_neighbors: int = None)
Fit the estimator to the input data and generate retrieval-augmented prompts.
- Parameters:
Xs (list of array-like objects) --
Xs length: 2
Xs[i] shape: (n_samples_i, 1)
A list with images and texts.
y (array-like of shape (n_samples,)) -- Target vector relative to Xs.
n_neighbors (int, default=None) -- Number of neighbors to retrieve per sample during prediction. If None, the value stored in the estimator is used.
- Returns:
database -- A DataFrame with the retrieval-augmented prompts. It contains the following columns:
  - item_id: Unique identifier for each sample.
  - img_path: Path to the image file.
  - text: Textual content of the sample.
  - label: Label of the sample.
  - observed_image: Indicator of whether the image was observed.
  - observed_text: Indicator of whether the text was observed.
  - i2i_id_list: List of ids of the items retrieved through the image-to-image channel.
  - i2i_sims_list: List of similarities of the items retrieved through the image-to-image channel.
  - i2i_label_list: List of labels of the items retrieved through the image-to-image channel.
  - prompt_image_path: Path to the generated image prompt. Only if generate_cap is True.
  - t2t_id_list: List of ids of the items retrieved through the text-to-text channel.
  - t2t_sims_list: List of similarities of the items retrieved through the text-to-text channel.
  - t2t_label_list: List of labels of the items retrieved through the text-to-text channel.
  - prompt_text_path: Path to the generated text prompt. Only if generate_cap is True.
- Return type:
pd.DataFrame (n_samples, 14) if generate_cap is True, otherwise (n_samples, 12)