Load

M3Care Dataset Loader

class imml.load.M3CareDataset(Xs, y)

Bases: object

This class provides a torch.utils.data.Dataset implementation for handling multi-modal datasets with M3Care.

Parameters:
  • Xs (list of array-like objects) --

    • Xs length: n_mods

    A list of different modalities.

  • y (array-like of shape (n_samples,)) -- Target vector relative to Xs.

Returns:

  • Xs (list of array-like objects) --

    • Xs length: n_mods

    A list of different modalities for one sample.

  • y (array-like of shape (n_samples,)) -- Target vector relative to the sample.

  • observed_mod_indicator (array-like of shape (1, n_mods)) -- Boolean array-like indicating observed modalities for the sample.
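The returned triple can be illustrated with a minimal, hypothetical map-style dataset (it implements `__getitem__` and `__len__`, the protocol `torch.utils.data.DataLoader` relies on). It assumes a modality counts as missing for a sample when every feature in that row is NaN, and it fills missing values with zeros; imml's actual M3CareDataset may detect and fill missing modalities differently. The class name `ToyM3CareDataset` is made up for this sketch.

```python
import numpy as np


class ToyM3CareDataset:
    """Hypothetical sketch returning (Xs, y, observed_mod_indicator) per sample.

    Assumption: a modality is missing for a sample when its whole row is NaN.
    """

    def __init__(self, Xs, y):
        # Convert every modality to a float array of shape (n_samples, n_features_i).
        self.Xs = [np.asarray(X, dtype=float) for X in Xs]
        self.y = np.asarray(y)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        # A modality is observed unless every feature of this sample is NaN.
        observed = np.array([[not np.isnan(X[idx]).all() for X in self.Xs]])
        # Replace NaNs with zeros so downstream tensors stay finite (illustrative choice).
        Xs_idx = [np.nan_to_num(X[idx], nan=0.0) for X in self.Xs]
        return Xs_idx, self.y[idx], observed


rng = np.random.default_rng(42)
Xs = [rng.random((4, 3)) for _ in range(2)]
Xs[1][2] = np.nan  # sample 2 lacks the second modality
data = ToyM3CareDataset(Xs, y=[0, 1, 0, 1])
Xs_2, y_2, observed_2 = data[2]
print(observed_2)  # [[ True False]]
```

The indicator has shape (1, n_mods), matching the documented return value.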

See also

M3Care

Example

>>> import numpy as np
>>> import pandas as pd
>>> import torch
>>> from torch.utils.data import DataLoader
>>> from imml.load import M3CareDataset
>>> Xs = [pd.DataFrame(np.random.default_rng(42).random((20, 10))) for _ in range(3)]
>>> y = torch.from_numpy(np.random.default_rng(42).integers(0, 2, len(Xs[0]))).float()
>>> train_data = M3CareDataset(Xs=Xs, y=y)
>>> train_dataloader = DataLoader(dataset=train_data)
>>> next(iter(train_dataloader))

MRGCN Dataset Loader

class imml.load.MRGCNDataset(Xs: list, transform=None)

Bases: object

This class provides a torch.utils.data.Dataset implementation for handling multi-modal datasets with MRGCN.

Parameters:
  • Xs (list of array-like objects) --

    • Xs length: n_mods

    • Xs[i] shape: (n_samples, n_features_i)

    A list of different modalities.

  • transform (list of callable, default=None) -- A list of functions or transformations to apply to each sample in the dataset.
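One plausible reading of the transform parameter is one callable per modality, applied to that modality's row whenever a sample is fetched. The sketch below assumes exactly that pairing; both the per-modality alignment and the class name `ToyMRGCNDataset` are hypothetical and not taken from imml, so check the imml source before relying on this behavior.

```python
import numpy as np


class ToyMRGCNDataset:
    """Hypothetical sketch: apply one optional transform per modality on access."""

    def __init__(self, Xs, transform=None):
        self.Xs = [np.asarray(X, dtype=float) for X in Xs]
        # Assumption: transform is a list aligned with Xs (one callable per modality).
        self.transform = transform

    def __len__(self):
        return len(self.Xs[0])

    def __getitem__(self, idx):
        sample = [X[idx] for X in self.Xs]
        if self.transform is not None:
            # Apply the i-th callable to the i-th modality of this sample.
            sample = [t(x) for t, x in zip(self.transform, sample)]
        return sample


def l2_normalize(x):
    # Scale a feature vector to unit Euclidean norm.
    return x / np.linalg.norm(x)


Xs = [np.ones((5, 4)), np.full((5, 2), 3.0)]
data = ToyMRGCNDataset(Xs, transform=[l2_normalize, lambda x: x - x.mean()])
first = data[0]
print(first[0])  # [0.5 0.5 0.5 0.5]
print(first[1])  # [0. 0.]
```

Applying transforms lazily in `__getitem__` keeps the stored data untouched, so the same arrays can be reused with different transforms.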

See also

MRGCN

Example

>>> import numpy as np
>>> import torch
>>> from imml.load import MRGCNDataset
>>> Xs = [torch.from_numpy(np.random.default_rng(42).random((20, 10))) for i in range(3)]
>>> train_data = MRGCNDataset(Xs=Xs)

MUSE Dataset Loader

class imml.load.MUSEDataset(Xs, y)

Bases: object

This class provides a torch.utils.data.Dataset implementation for handling multi-modal datasets with MUSE.

Parameters:
  • Xs (list of array-like objects) --

    • Xs length: n_mods

    A list of different modalities.

  • y (array-like of shape (n_samples,)) -- Target vector relative to Xs.

Returns:

  • Xs_idx (list of array-like objects) --

    • Xs_idx length: n_mods

    A list of different modalities for one sample.

  • y_idx (array-like of shape (n_samples,)) -- Target vector relative to the sample.

  • missing_mod_indicator (array-like of shape (1, n_mods)) -- Boolean array-like indicating missing modalities for the sample.

  • y_indicator (array-like of shape (1,)) -- Boolean array-like indicating observed label for the sample.
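The two MUSE indicators flag different things: missing_mod_indicator marks absent modalities, while y_indicator marks whether the label is known, so unlabeled samples can still be passed to the model. The sketch below shows how such flags might be derived for one sample. It assumes missing labels are encoded as NaN and missing modalities as all-NaN rows (both encodings are assumptions), and `muse_indicators` is a hypothetical helper, not part of imml.

```python
import numpy as np


def muse_indicators(Xs, y, idx):
    """Hypothetical helper returning (Xs_idx, y_idx, missing_mod_indicator, y_indicator)."""
    # Fill NaNs with zeros so the returned features are finite.
    Xs_idx = [np.nan_to_num(X[idx], nan=0.0) for X in Xs]
    # True where the whole modality row is NaN, i.e. the modality is missing.
    missing = np.array([[np.isnan(X[idx]).all() for X in Xs]])
    # True when the label is observed; a NaN label is replaced by 0 as a placeholder.
    y_observed = np.array([not np.isnan(y[idx])])
    y_idx = 0.0 if np.isnan(y[idx]) else float(y[idx])
    return Xs_idx, y_idx, missing, y_observed


Xs = [np.arange(6, dtype=float).reshape(3, 2), np.full((3, 2), np.nan)]
y = np.array([1.0, np.nan, 0.0])
_, y1, missing1, y_ind1 = muse_indicators(Xs, y, 1)
print(missing1, y_ind1, y1)  # [[False  True]] [False] 0.0
```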

See also

MUSE

Example

>>> import numpy as np
>>> import pandas as pd
>>> import torch
>>> from torch.utils.data import DataLoader
>>> from imml.load import MUSEDataset
>>> Xs = [pd.DataFrame(np.random.default_rng(42).random((20, 10))) for i in range(3)]
>>> y = torch.from_numpy(np.random.default_rng(42).integers(0, 2, len(Xs[0]))).float()
>>> train_data = MUSEDataset(Xs=Xs, y=y)
>>> train_dataloader = DataLoader(dataset=train_data)
>>> next(iter(train_dataloader))

RAGPT Dataset Loader

class imml.load.RAGPTDataset(Xs: list, y, mcr=None, Xs_bank: list = None, y_bank=None, batch_size: int = 64, n_neighbors: int = 20, device: str = 'cpu', modalities: list = None, pretrained_model=None, processor=None, prompt_path: str = None, pretrained_vilt=None, tokenizer=None, image_processor=None, max_text_len: int = 40, max_image_len: int = 145)

Bases: object

This class provides a torch.utils.data.Dataset implementation for handling multi-modal datasets with RAGPT. When it is used with torch.utils.data.DataLoader, pass RAGPTCollator as the collate_fn argument of the DataLoader constructor.
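A collate_fn receives the list of samples that DataLoader drew for one batch and merges them into a single batch object. The sketch below shows the general shape of such a function for dictionary samples. `toy_collate` and the sample keys are hypothetical stand-ins used only to illustrate the mechanism; this is not RAGPTCollator, whose sample keys and padding rules are not documented here.

```python
import numpy as np


def toy_collate(samples):
    """Hypothetical collate_fn: merge a list of dict samples into one dict of batched arrays."""
    batch = {}
    for key in samples[0]:
        values = [s[key] for s in samples]
        # Stack numeric fields along a new batch axis; keep other fields (e.g. paths) as lists.
        if isinstance(values[0], (int, float, np.ndarray)):
            batch[key] = np.stack([np.asarray(v) for v in values])
        else:
            batch[key] = values
    return batch


samples = [
    {"img_path": "a.png", "label": 0, "observed_text": 1.0},
    {"img_path": "b.png", "label": 1, "observed_text": 0.0},
]
batch = toy_collate(samples)
print(batch["img_path"], batch["label"])  # ['a.png', 'b.png'] [0 1]
```

With PyTorch installed, a function like this is passed as `DataLoader(dataset, batch_size=2, collate_fn=toy_collate)`; for RAGPTDataset the documented choice is RAGPTCollator.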

Parameters:
  • batch_size (int, default=64) -- Batch size used for encoding inputs during memory bank creation and inference.

  • n_neighbors (int, default=20) -- Number of neighbors to retrieve per sample during prediction.

  • device (str, default="cpu") -- Device to use for model inference, typically "cpu" or "cuda".

  • modalities (list of str, default=None) -- Names of the modalities. Options are "text" and "image".

  • pretrained_model (transformers.PreTrainedModel, default=None) -- A pretrained HuggingFace model used for encoding multimodal inputs (e.g., CLIP model). If None, defaults to "openai/clip-vit-large-patch14-336".

  • processor (transformers.ProcessorMixin, default=None) -- HuggingFace processor corresponding to the pretrained model. Used to preprocess image/text inputs. If None, defaults to processor for "openai/clip-vit-large-patch14-336".

  • prompt_path (str, default=None) -- Path to save or load the generated prompts when generate_cap is True.

  • pretrained_vilt (transformers.PreTrainedModel, default=None) -- Pretrained model used for vision-language prompt generation. If None, defaults to ViltModel.from_pretrained('dandelin/vilt-b32-mlm').

  • tokenizer (transformers.BertTokenizer, default=None) -- Tokenizer used for text processing. If None, defaults to BertTokenizer.from_pretrained('dandelin/vilt-b32-mlm', do_lower_case=True).

  • image_processor (transformers.ViltImageProcessor, default=None) -- Image processor used with the ViLT model for image preprocessing. If None, defaults to ViltImageProcessor.from_pretrained('dandelin/vilt-b32-mlm').

  • max_text_len (int, default=40) -- Maximum token length for text inputs (used during prompt generation). Must not exceed the max_position_embeddings of the ViLT model (default: 40 for 'dandelin/vilt-b32-mlm').

  • max_image_len (int, default=145) -- Maximum token length for image inputs (used during prompt generation).

Attributes:
  • mcr_ -- Multi-Channel Retriever (MCR) model for image-text retrieval.

  • input_ids_list_ -- Unique identifier for each sample.

  • img_path_list_ -- Paths to the image files.

  • attention_mask_list_ -- Attention masks for the texts.

  • token_type_ids_list_ -- Token type ids for the texts.

  • i2i_r_l_list_ -- List of labels of the retrieved items for the image-to-image modality.

  • t2t_r_l_list_ -- List of labels of the retrieved items for the text-to-text modality.

  • label_list_ -- Label of the sample.

  • prompt_image_path_ -- Path to the generated image prompt.

  • prompt_text_path_ -- Path to the generated text prompt.

  • observed_image_ -- Indicator of whether the image was observed.

  • observed_text_ -- Indicator of whether the text was observed.
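Attributes such as i2i_r_l_list_ and t2t_r_l_list_ hold the labels of the n_neighbors items retrieved from the memory bank. As a rough illustration of that step, the sketch below retrieves neighbors by cosine similarity over precomputed embeddings, processing queries in chunks of batch_size. It assumes random vectors in place of the CLIP embeddings produced by pretrained_model; `retrieve_labels` is a hypothetical helper, and imml's MCR-based retrieval may differ in scoring, self-exclusion, and tie-breaking.

```python
import numpy as np


def retrieve_labels(queries, bank, bank_labels, n_neighbors=20, batch_size=64):
    """Hypothetical helper: labels of the n_neighbors most cosine-similar bank items per query."""
    # L2-normalize so that a dot product equals cosine similarity.
    q = queries / np.linalg.norm(queries, axis=1, keepdims=True)
    b = bank / np.linalg.norm(bank, axis=1, keepdims=True)
    bank_labels = np.asarray(bank_labels)
    out = []
    for start in range(0, len(q), batch_size):
        sims = q[start:start + batch_size] @ b.T  # shape (chunk, n_bank)
        # Indices of the highest similarities, best first.
        top = np.argsort(-sims, axis=1)[:, :n_neighbors]
        out.append(bank_labels[top])
    return np.concatenate(out, axis=0)


rng = np.random.default_rng(0)
bank = rng.normal(size=(10, 8))
labels = np.arange(10) % 2
# Querying with the bank itself: each item's nearest neighbor is itself.
retrieved = retrieve_labels(bank, bank, labels, n_neighbors=3, batch_size=4)
print(retrieved.shape)  # (10, 3)
```

Chunking by batch_size bounds the size of each similarity matrix, which matters when the memory bank is large.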

See also

RAGPT

Example

>>> import tempfile
>>> import pandas as pd
>>> from torch.utils.data import DataLoader
>>> from imml.load import RAGPTDataset
>>> Xs = [
...     pd.DataFrame(["docs/figures/graph.png", "docs/figures/logo_imml.png"]),
...     pd.DataFrame(["This is the graphical abstract of iMML.", "This is the logo of iMML."]),
... ]
>>> y = [0, 1]
>>> modalities = ["image", "text"]
>>> tmp_path = tempfile.mkdtemp()
>>> train_data = RAGPTDataset(Xs=Xs, y=y, Xs_bank=Xs, y_bank=y, modalities=modalities,
...                           n_neighbors=1, prompt_path=str(tmp_path))