Source code for imml.load.m3care_dataset

# License: BSD-3-Clause
import numpy as np
import pandas as pd

from ..impute import get_observed_mod_indicator
from ..utils import check_Xs_y
from .. import deepmodule_installed, deepmodule_error, Dataset

if deepmodule_installed:
    import torch


class M3CareDataset(Dataset):
    r"""
    This class provides a `torch.utils.data.Dataset` implementation for handling multi-modal datasets with `M3Care`.

    Parameters
    ----------
    Xs : list of array-likes objects
        - Xs length: n_mods

        A list of different modalities.
    y : array-like of shape (n_samples,)
        Target vector relative to X.

    Returns
    -------
    Xs: list of array-likes objects
        - Xs length: n_mods

        A list of different modalities for one sample.
    y: array-like of shape (n_samples,)
        Target vector relative to the sample.
    observed_mod_indicator: array-like of shape (1, n_mods)
        Boolean array-like indicating observed modalities for the sample.

    See Also
    --------
    :class:`~imml.classify.M3Care`

    Example
    --------
    >>> import torch
    >>> from torch.utils.data import DataLoader
    >>> import numpy as np
    >>> import pandas as pd
    >>> from imml.load import M3CareDataset
    >>> Xs = [pd.DataFrame(np.random.default_rng(42).random((20, 10))) for i in range(3)]
    >>> y = torch.from_numpy(np.random.default_rng(42).integers(0, 2, len(Xs[0]))).float()
    >>> train_data = M3CareDataset(Xs=Xs, y=y)
    >>> train_dataloader = DataLoader(dataset=train_data)
    >>> next(iter(train_dataloader))
    """

    def __init__(self, Xs, y):
        # Fail early with the package's own message when the deep-learning
        # extras (torch) are not available.
        if not deepmodule_installed:
            raise ImportError(deepmodule_error)

        Xs = check_Xs_y(Xs=Xs, y=y, supervised=True)

        # Normalise the observed-modality indicator to a boolean tensor.
        observed_mod_indicator = get_observed_mod_indicator(Xs)
        if isinstance(observed_mod_indicator, pd.DataFrame):
            observed_mod_indicator = observed_mod_indicator.values
        if isinstance(observed_mod_indicator, np.ndarray):
            observed_mod_indicator = torch.from_numpy(observed_mod_indicator)
        observed_mod_indicator = observed_mod_indicator.bool()

        # pandas / numpy targets become tensors; anything else is kept as is.
        if isinstance(y, (pd.DataFrame, pd.Series)):
            y = y.values
        if isinstance(y, np.ndarray):
            y = torch.from_numpy(y)

        self.Xs = [self._prepare_modality(X) for X in Xs]
        self.y = y
        self.observed_mod_indicator = observed_mod_indicator

    @staticmethod
    def _prepare_modality(X):
        """Convert one modality to a float tensor, or to a list of text rows.

        Parameters
        ----------
        X : pd.DataFrame, np.ndarray or other
            A single modality.

        Returns
        -------
        torch.Tensor, list or other
            A float tensor for non-object arrays, a list of rows for object
            arrays (rows whose first entry is not a string are replaced by
            ``[""]``), or ``X`` unchanged when it is neither a DataFrame nor
            an ndarray.
        """
        # The type check must look at this modality, not at the first one in
        # the list: modalities may be a mix of DataFrames and ndarrays.
        if isinstance(X, pd.DataFrame):
            X = X.values
        if isinstance(X, np.ndarray):
            if X[:, 0].dtype == object:
                # Object arrays are treated as text; a row whose first entry
                # is not a string (e.g. NaN) becomes an empty-string row.
                X = X.tolist()
                X = [row if isinstance(row[0], str) else [""] for row in X]
            else:
                X = torch.from_numpy(X).float()
        return X

    def __len__(self):
        """Return the number of rows of the observed-modality indicator."""
        return len(self.observed_mod_indicator)

    def __getitem__(self, idx):
        """Return ``(Xs, y, observed_mod_indicator)`` for the sample at ``idx``.

        For each modality, the row's first entry is returned when it is a
        string; otherwise the whole row is returned.
        """
        Xs = []
        for X in self.Xs:
            row = X[idx]
            first = row[0]
            Xs.append(first if isinstance(first, str) else row)
        sample = Xs, self.y[idx], self.observed_mod_indicator[idx]
        return sample