Load

M3Care Dataset Loader

class imml.load.M3CareDataset(Xs, y, observed_mod_indicator)[source]

Bases: Dataset

This class provides a torch.utils.data.Dataset implementation for handling multi-modal datasets with M3Care.

Parameters:
  • Xs (list of array-like objects) --

    • Xs length: n_mods

    A list of different modalities.

  • y (array-like of shape (n_samples,)) -- Target vector relative to Xs.

  • observed_mod_indicator (array-like of shape (n_samples, n_mods)) -- Boolean array-like indicating observed modalities for each sample.

Returns:

  • Xs_idx (list of array-like objects) --

    • Xs length: n_mods

    A list of different modalities for one sample.

  • y_idx (array-like of shape (n_samples,)) -- Target vector relative to the sample.

  • observed_mod_indicator (array-like of shape (1, n_mods)) -- Boolean array-like indicating observed modalities for the sample.
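
The per-sample return contract above can be illustrated with a small stand-in. This is a minimal sketch, not the iMML implementation: ToyMultiModalDataset is a hypothetical pure-Python class that only mirrors the documented ordering (Xs_idx, y_idx, observed_mod_indicator). It assumes plain lists instead of tensors and assumes the three items come back as a tuple.

```python
class ToyMultiModalDataset:
    """Hypothetical stand-in mirroring the documented return contract."""

    def __init__(self, Xs, y, observed_mod_indicator):
        # Xs: list of n_mods modalities, each a list of n_samples rows.
        self.Xs = Xs
        self.y = y
        self.observed_mod_indicator = observed_mod_indicator

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        # One row per modality, the sample's target, and its modality mask.
        Xs_idx = [X[idx] for X in self.Xs]
        return Xs_idx, self.y[idx], self.observed_mod_indicator[idx]


# Two modalities, three samples; sample 1 lacks the second modality.
Xs = [[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
      [[1.0], [float("nan")], [3.0]]]
y = [0, 1, 0]
observed = [[True, True], [True, False], [True, True]]

data = ToyMultiModalDataset(Xs, y, observed)
Xs_idx, y_idx, observed_idx = data[1]
```

Here observed_idx tells a downstream model that the second entry of Xs_idx is a placeholder and should not be trusted.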

Example

>>> import numpy as np
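>>> import pandas as pd
>>> import torch
>>> from imml.impute import get_observed_mod_indicator  # assumed location of this helper
>>> from imml.load import M3CareDataset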
>>> Xs = [pd.DataFrame(np.random.default_rng(42).random((20, 10))) for i in range(3)]
>>> Xs = [torch.from_numpy(X.values).float() for X in Xs]
>>> observed_mod_indicator = torch.from_numpy(get_observed_mod_indicator(Xs).values)
>>> y = torch.from_numpy(np.random.default_rng(42).integers(0, 2, len(Xs[0]))).float()
>>> train_data = M3CareDataset(Xs=Xs, observed_mod_indicator=observed_mod_indicator, y=y)

MRGCN Dataset Loader

class imml.load.MRGCNDataset(Xs: List, transform=None)[source]

Bases: Dataset

This class provides a torch.utils.data.Dataset implementation for handling multi-modal datasets with MRGCN.

Parameters:
  • Xs (list of array-like objects) --

    • Xs length: n_mods

    • Xs[i] shape: (n_samples, n_features_i)

    A list of different modalities.

  • transform (list of callable, default=None) -- A list of functions or transformations to apply to each sample in the dataset.
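
One plausible reading of the transform parameter is that the callables run in list order, each receiving the output of the previous one. The sketch below illustrates that reading only; apply_transforms, center, and scale_by_two are hypothetical helpers, and the actual behavior of MRGCNDataset may differ (for example, one callable per modality).

```python
def apply_transforms(sample, transform=None):
    """Run each callable in `transform` over the sample, in list order."""
    if transform is None:
        return sample
    for fn in transform:
        sample = fn(sample)
    return sample


def center(sample):
    # Subtract each modality's mean from its own features.
    return [[v - sum(x) / len(x) for v in x] for x in sample]


def scale_by_two(sample):
    # Double every feature in every modality.
    return [[2 * v for v in x] for x in sample]


sample = [[1.0, 3.0], [4.0]]  # one sample, two modalities
transformed = apply_transforms(sample, transform=[center, scale_by_two])
```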

Example

>>> import numpy as np
>>> import torch
>>> from imml.load import MRGCNDataset
>>> Xs = [torch.from_numpy(np.random.default_rng(42).random((20, 10))) for i in range(3)]
>>> train_data = MRGCNDataset(Xs=Xs)

MUSE Dataset Loader

class imml.load.MUSEDataset(Xs, y, observed_mod_indicator, y_indicator)[source]

Bases: Dataset

This class provides a torch.utils.data.Dataset implementation for handling multi-modal datasets with MUSE.

Parameters:
  • Xs (list of array-like objects) --

    • Xs length: n_mods

    A list of different modalities.

  • y (array-like of shape (n_samples,)) -- Target vector relative to Xs.

  • observed_mod_indicator (array-like of shape (n_samples, n_mods)) -- Boolean array-like indicating observed modalities for each sample.

  • y_indicator (array-like of shape (n_samples,)) -- Boolean array-like indicating observed label for each sample.

Returns:

  • Xs_idx (list of array-like objects) --

    • Xs length: n_mods

    A list of different modalities for one sample.

  • y_idx (array-like of shape (n_samples,)) -- Target vector relative to the sample.

  • observed_mod_indicator (array-like of shape (1, n_mods)) -- Boolean array-like indicating observed modalities for the sample.

  • y_indicator (array-like of shape (1,)) -- Boolean array-like indicating observed label for the sample.
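
A label indicator like y_indicator is typically consumed by masking the loss so that unlabeled samples do not contribute. The sketch below shows that idea in plain Python; masked_mean is a hypothetical helper and is not claimed to be how MUSE computes its loss.

```python
def masked_mean(per_sample_loss, y_indicator):
    """Average the loss over samples whose label is observed."""
    kept = [loss for loss, seen in zip(per_sample_loss, y_indicator) if seen]
    if not kept:
        return 0.0  # no labeled sample in this batch
    return sum(kept) / len(kept)


losses = [0.2, 0.9, 0.4]           # per-sample losses for a batch of three
y_indicator = [True, False, True]  # the second sample has no label
batch_loss = masked_mean(losses, y_indicator)
```

The second loss value is ignored, so batch_loss is the mean of 0.2 and 0.4 only.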

Example

>>> import numpy as np
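>>> import pandas as pd
>>> import torch
>>> from imml.impute import get_observed_mod_indicator  # assumed location of this helper
>>> from imml.load import MUSEDataset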
>>> Xs = [pd.DataFrame(np.random.default_rng(42).random((20, 10))) for i in range(3)]
>>> Xs = [torch.from_numpy(X.values).float() for X in Xs]
>>> observed_mod_indicator = torch.from_numpy(get_observed_mod_indicator(Xs).values)
>>> y = torch.from_numpy(np.random.default_rng(42).integers(0, 2, len(Xs[0]))).float()
>>> y_indicator = torch.ones((len(Xs[0]))).bool()
>>> train_data = MUSEDataset(Xs=Xs, observed_mod_indicator=observed_mod_indicator,
...                          y=y, y_indicator=y_indicator)

RAGPT Dataset Loader

class imml.load.RAGPTDataset(database: DataFrame, max_text_len: int = 128)[source]

Bases: Dataset

This class provides a torch.utils.data.Dataset implementation for handling multi-modal datasets with RAGPT. When it is used with torch.utils.data.DataLoader, pass a RAGPTCollator instance as the collate_fn argument of the DataLoader constructor.

Parameters:
  • database (pd.DataFrame of shape (n_samples, 14)) --

    A database with the retrieval-augmented prompts created by MCR. It must contain the following columns:
    • item_id: Unique identifier for each sample.

    • img_path: Path to the image file.

    • text: Textual content of the sample.

    • label: Label of the sample.

    • observed_image: Indicator of whether the image was observed.

    • observed_text: Indicator of whether the text was observed.

    • i2i_id_list: List of ids of the retrieved items for the image-to-image modality.

    • i2i_sims_list: List of similarities of the retrieved items for the image-to-image modality.

    • i2i_label_list: List of labels of the retrieved items for the image-to-image modality.

    • prompt_image_path: Path to the generated image prompt. Only if generate_cap is True.

    • t2t_id_list: List of ids of the retrieved items for the text-to-text modality.

    • t2t_sims_list: List of similarities of the retrieved items for the text-to-text modality.

    • t2t_label_list: List of labels of the retrieved items for the text-to-text modality.

    • prompt_text_path: Path to the generated text prompt. Only if generate_cap is True.

  • max_text_len (int, default=128) -- Maximum token length for text inputs (used during prompt generation).
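
Before building the dataset, it can help to confirm that the database carries the columns listed above. The following is a minimal sketch: missing_columns is a hypothetical helper, not part of iMML, and it treats all 14 listed columns as required, matching the stated shape, even though the two prompt_* columns are documented as conditional on generate_cap.

```python
REQUIRED_COLUMNS = (
    "item_id", "img_path", "text", "label",
    "observed_image", "observed_text",
    "i2i_id_list", "i2i_sims_list", "i2i_label_list", "prompt_image_path",
    "t2t_id_list", "t2t_sims_list", "t2t_label_list", "prompt_text_path",
)


def missing_columns(columns):
    """Return the documented column names absent from `columns`, in order."""
    present = set(columns)
    return [name for name in REQUIRED_COLUMNS if name not in present]


# Works with any iterable of names, e.g. a DataFrame's `.columns`.
absent = missing_columns(["item_id", "text", "label"])
```

An empty result means every documented column is present.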

Returns:

sample --

Dictionary with the following keys for one sample:
  • image: Image of the sample.

  • text: Textual content of the sample.

  • label: Label of the sample.

  • r_t_list: List with retrieved textual content for the sample.

  • r_i_list: List with retrieved image content for the sample.

  • r_l_list: List with retrieved labels for the sample.

  • observed_text: True if the text is observed, False otherwise.

  • observed_image: True if the image is observed, False otherwise.

Return type:

dict
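
Because each sample is a dict, batching needs a collate step that merges per-sample dicts into a single batch. The sketch below shows the generic idea only; collate_dicts is a hypothetical function and does not reproduce what RAGPTCollator does internally.

```python
def collate_dicts(samples):
    """Turn a list of per-sample dicts into one dict of lists."""
    keys = samples[0].keys()
    return {key: [sample[key] for sample in samples] for key in keys}


batch = collate_dicts([
    {"text": "a logo", "label": 1, "observed_text": True},
    {"text": "", "label": 0, "observed_text": False},
])
```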

Example

>>> from torch.utils.data import DataLoader
>>> from imml.load import RAGPTCollator, RAGPTDataset  # RAGPTCollator location assumed
>>> from imml.retrieve import MCR
>>> images = ["docs/figures/graph.png", "docs/figures/logo_imml.png",
...           "docs/figures/graph.png", "docs/figures/logo_imml.png"]
>>> texts = ["This is the graphical abstract of iMML.", "This is the logo of iMML.",
...          "This is the graphical abstract of iMML.", "This is the logo of iMML."]
>>> Xs = [images, texts]
>>> y = [0, 1, 0, 1]
>>> modalities = ["image", "text"]
>>> estimator = MCR(modalities=modalities)
>>> database = estimator.fit_transform(Xs=Xs, y=y)
>>> train_data = RAGPTDataset(database=database)
>>> train_dataloader = DataLoader(train_data, collate_fn=RAGPTCollator())