Load
M3Care Dataset Loader
- class imml.load.M3CareDataset(Xs, y)
Bases: object
This class provides a torch.utils.data.Dataset implementation for handling multi-modal datasets with M3Care.
- Parameters:
Xs (list of array-like objects of length n_mods) -- A list of different modalities.
y (array-like of shape (n_samples,)) -- Target vector relative to Xs.
- Returns:
Xs (list of array-like objects of length n_mods) -- A list of different modalities for one sample.
y (array-like of shape (n_samples,)) -- Target vector relative to the sample.
observed_mod_indicator (array-like of shape (1, n_mods)) -- Boolean array-like indicating observed modalities for the sample.
Example
>>> import numpy as np
>>> import pandas as pd
>>> import torch
>>> from torch.utils.data import DataLoader
>>> from imml.load import M3CareDataset
>>> Xs = [pd.DataFrame(np.random.default_rng(42).random((20, 10))) for i in range(3)]
>>> y = torch.from_numpy(np.random.default_rng(42).integers(0, 2, len(Xs[0]))).float()
>>> train_data = M3CareDataset(Xs=Xs, y=y)
>>> train_dataloader = DataLoader(dataset=train_data)
>>> next(iter(train_dataloader))
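The return values listed above can be illustrated with a small stand-in class. This is only a sketch and not the M3CareDataset implementation: the class name ToyMultiModalDataset is hypothetical, it uses plain Python lists instead of torch tensors, and it assumes that a missing modality is encoded as a row in which every value is NaN.

```python
import math


class ToyMultiModalDataset:
    """Hypothetical stand-in that mimics the documented return tuple."""

    def __init__(self, Xs, y):
        self.Xs = Xs  # list of n_mods modalities, each a list of per-sample rows
        self.y = y    # one target per sample

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        # Pick the row belonging to this sample from every modality.
        sample = [X[idx] for X in self.Xs]
        # Assumption: a modality counts as observed unless every value is NaN.
        observed = [not all(math.isnan(v) for v in row) for row in sample]
        return sample, self.y[idx], observed


nan = float("nan")
Xs = [
    [[0.1, 0.2], [nan, nan]],  # modality 0: sample 1 is missing
    [[1.0, 2.0], [3.0, 4.0]],  # modality 1: fully observed
]
dataset = ToyMultiModalDataset(Xs=Xs, y=[0.0, 1.0])
sample, target, observed = dataset[1]
```

Indexing returns one row per modality, the target of that sample, and one boolean per modality, which mirrors the Xs, y, observed_mod_indicator triple described in the Returns section.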
MRGCN Dataset Loader
- class imml.load.MRGCNDataset(Xs: list, transform=None)
Bases: object
This class provides a torch.utils.data.Dataset implementation for handling multi-modal datasets with MRGCN.
- Parameters:
Example
>>> import numpy as np
>>> import torch
>>> from imml.load import MRGCNDataset
>>> Xs = [torch.from_numpy(np.random.default_rng(42).random((20, 10))) for i in range(3)]
>>> train_data = MRGCNDataset(Xs=Xs)
MUSE Dataset Loader
- class imml.load.MUSEDataset(Xs, y)
Bases: object
This class provides a torch.utils.data.Dataset implementation for handling multi-modal datasets with MUSE.
- Parameters:
Xs (list of array-like objects of length n_mods) -- A list of different modalities.
y (array-like of shape (n_samples,)) -- Target vector relative to Xs.
- Returns:
Xs_idx (list of array-like objects of length n_mods) -- A list of different modalities for one sample.
y_idx (array-like of shape (n_samples,)) -- Target vector relative to the sample.
missing_mod_indicator (array-like of shape (1, n_mods)) -- Boolean array-like indicating missing modalities for the sample.
y_indicator (array-like of shape (1,)) -- Boolean array-like indicating observed label for the sample.
Example
>>> import numpy as np
>>> import pandas as pd
>>> import torch
>>> from torch.utils.data import DataLoader
>>> from imml.load import MUSEDataset
>>> Xs = [pd.DataFrame(np.random.default_rng(42).random((20, 10))) for i in range(3)]
>>> y = torch.from_numpy(np.random.default_rng(42).integers(0, 2, len(Xs[0]))).float()
>>> train_data = MUSEDataset(Xs=Xs, y=y)
>>> train_dataloader = DataLoader(dataset=train_data)
>>> next(iter(train_dataloader))
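Unlike the observed-modality flag, the two MUSE indicators have opposite polarity: missing_mod_indicator is True where a modality is absent, while y_indicator is True where the label is present. The sketch below shows one way such flags could be derived for a single sample. The helper names are hypothetical, and encoding missing values and missing labels as NaN is an assumption rather than something this page states.

```python
import math


def is_nan(value):
    """Return True for a float NaN, False for anything else."""
    return isinstance(value, float) and math.isnan(value)


def muse_style_indicators(sample_mods, label):
    """Hypothetical helper: build the two flags for a single sample."""
    # True where a modality is missing (assumed: every feature is NaN).
    missing_mod_indicator = [all(is_nan(v) for v in mod) for mod in sample_mods]
    # True where the label is available (assumed: a missing label is NaN).
    y_indicator = not is_nan(label)
    return missing_mod_indicator, y_indicator


nan = float("nan")
# Three modalities: the second is entirely missing, the third only partially.
missing, has_label = muse_style_indicators(
    [[0.5, 0.7], [nan, nan], [1.0, nan]], nan
)
```

A partially filled modality is not flagged as missing under this assumption; only a modality with no observed value at all is.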
RAGPT Dataset Loader
- class imml.load.RAGPTDataset(Xs: list, y, mcr=None, Xs_bank: list = None, y_bank=None, batch_size: int = 64, n_neighbors: int = 20, device: str = 'cpu', modalities: list = None, pretrained_model=None, processor=None, prompt_path: str = None, pretrained_vilt=None, tokenizer=None, image_processor=None, max_text_len: int = 40, max_image_len: int = 145)
Bases: object
This class provides a torch.utils.data.Dataset implementation for handling multi-modal datasets with RAGPT. If it is used with torch.utils.data.DataLoader, the collate_fn argument of the DataLoader constructor should be RAGPTCollator.
- Parameters:
batch_size (int, default=64) -- Batch size used for encoding inputs during memory bank creation and inference.
n_neighbors (int, default=20) -- Number of neighbors to retrieve per sample during prediction.
device (str, default="cpu") -- Device to use for model inference, typically "cpu" or "cuda".
modalities (list of str, default=None) -- Names of the modalities. Options are "text" and "image".
pretrained_model (transformers.PreTrainedModel, default=None) -- A pretrained HuggingFace model used for encoding multimodal inputs (e.g., CLIP model). If None, defaults to "openai/clip-vit-large-patch14-336".
processor (transformers.ProcessorMixin, default=None) -- HuggingFace processor corresponding to the pretrained model. Used to preprocess image/text inputs. If None, defaults to processor for "openai/clip-vit-large-patch14-336".
prompt_path (str, default=None) -- Path to save or load the generated prompts when generate_cap is True.
pretrained_vilt (transformers.PreTrainedModel, default=None) -- Pretrained model used for vision-language prompt generation. If None, defaults to ViltModel.from_pretrained('dandelin/vilt-b32-mlm').
tokenizer (transformers.BertTokenizer, default=None) -- Tokenizer used for text processing. If None, defaults to BertTokenizer.from_pretrained('dandelin/vilt-b32-mlm', do_lower_case=True).
image_processor (transformers.ViltImageProcessor, default=None) -- Image processor used with the ViLT model for image preprocessing. If None, defaults to ViltImageProcessor.from_pretrained('dandelin/vilt-b32-mlm').
max_text_len (int, default=40) -- Maximum token length for text inputs (used during prompt generation). Must not exceed the max_position_embeddings of the ViLT model (default: 40 for 'dandelin/vilt-b32-mlm').
max_image_len (int, default=145) -- Maximum token length for image inputs (used during prompt generation).
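The n_neighbors parameter sets how many neighbors are retrieved per sample. As a rough illustration of top-k retrieval from a memory bank, the sketch below ranks bank embeddings by cosine similarity to a query. The function names are hypothetical and the choice of cosine similarity is an assumption; this page does not state which metric the retriever uses.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def retrieve_neighbors(query, bank, n_neighbors=20):
    """Return indices of the n_neighbors bank entries most similar to query."""
    scores = [(cosine_similarity(query, emb), i) for i, emb in enumerate(bank)]
    # Sort by descending similarity; ties fall back to the lower index.
    scores.sort(key=lambda s: (-s[0], s[1]))
    return [i for _, i in scores[:n_neighbors]]


# Toy memory bank of 2-dimensional embeddings (illustrative values only).
bank = [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1], [-1.0, 0.0]]
neighbors = retrieve_neighbors([1.0, 0.0], bank, n_neighbors=2)
```

If the bank holds fewer entries than n_neighbors, the sketch simply returns all of them in ranked order.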
- Attributes:
mcr_ -- Multi-Channel Retriever (MCR) model for image-text retrieval.
input_ids_list_ -- Unique identifier for each sample.
img_path_list_ -- Path to the image files.
attention_mask_list_ -- Attention masks for texts.
token_type_ids_list_ -- Token ids for texts.
i2i_r_l_list_ -- List of labels of the retrieved items for the image-to-image modality.
t2t_r_l_list_ -- List of labels of the retrieved items for the text-to-text modality.
label_list_ -- Label of the sample.
prompt_image_path_ -- Path to the generated image prompt.
prompt_text_path_ -- Path to the generated text prompt.
observed_image_ -- Indicator of whether the image was observed.
observed_text_ -- Indicator of whether the text was observed.
Example
>>> import tempfile
>>> import pandas as pd
>>> from torch.utils.data import DataLoader
>>> from imml.load import RAGPTDataset
>>> Xs = [
...     pd.DataFrame(["docs/figures/graph.png", "docs/figures/logo_imml.png"]),
...     pd.DataFrame(["This is the graphical abstract of iMML.", "This is the logo of iMML."]),
... ]
>>> y = [0, 1, 0, 1]
>>> modalities = ["image", "text"]
>>> tmp_path = tempfile.mkdtemp()
>>> train_data = RAGPTDataset(Xs=Xs, y=y, Xs_bank=Xs, y_bank=y, modalities=modalities, n_neighbors=1, prompt_path=str(tmp_path))
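The class description says a DataLoader built on this dataset should receive RAGPTCollator as its collate_fn. A DataLoader calls collate_fn with the list of samples that make up one batch and yields whatever that callable returns. The sketch below shows the general idea with a hypothetical collate function over per-sample dictionaries. It is not RAGPTCollator, and the field names are illustrative only.

```python
def collate_samples(batch):
    """Hypothetical collate function: list of per-sample dicts -> dict of lists."""
    if not batch:
        return {}
    keys = batch[0].keys()
    # Gather each field across the batch so that it can be processed jointly.
    return {key: [sample[key] for sample in batch] for key in keys}


# Two toy samples; the keys are illustrative, not the fields RAGPTDataset emits.
batch = [
    {"label": 0, "observed_image": True, "observed_text": False},
    {"label": 1, "observed_image": True, "observed_text": True},
]
collated = collate_samples(batch)
```

A real collator would typically also pad sequences and stack tensors; the sketch only shows the regrouping from per-sample records to per-field batches.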