# Source code for imml.load.mrgcn_dataset

# License: BSD-3-Clause

from ..utils import check_Xs_y
from .. import deepmodule_installed, deepmodule_error, Dataset

# torch is imported only when the package-level flag reports that the deep
# learning dependencies are available, so importing this module does not fail
# without them (MRGCNDataset.__init__ raises ImportError in that case instead).
if deepmodule_installed:
    import torch


class MRGCNDataset(Dataset):
    r"""
    This class provides a `torch.utils.data.Dataset` implementation for handling multi-modal datasets with `MRGCN`.

    Parameters
    ----------
    Xs : list of array-likes objects
        - Xs length: n_mods
        - Xs[i] shape: (n_samples, n_features_i)

        A list of different modalities.
    transform : list of callable, default=None
        One callable per modality: ``transform[i]`` is applied to every sample
        drawn from ``Xs[i]``. The list must therefore contain at least
        ``n_mods`` callables. If None, samples are returned unchanged.

    See Also
    --------
    :class:`~imml.cluster.MRGCN`

    Example
    --------
    >>> import numpy as np
    >>> import torch
    >>> from imml.load import MRGCNDataset
    >>> Xs = [torch.from_numpy(np.random.default_rng(42).random((20, 10))) for i in range(3)]
    >>> train_data = MRGCNDataset(Xs=Xs)
    """

    def __init__(self, Xs: list, transform=None):
        # Fail early with the package's own message when the optional deep
        # learning dependencies are missing.
        if not deepmodule_installed:
            raise ImportError(deepmodule_error)

        # Input validation is delegated to the project helper; whatever it
        # returns is what gets stored and indexed afterwards.
        Xs = check_Xs_y(Xs=Xs)

        self.Xs = Xs
        self.transform = transform

    def __len__(self):
        """Return the number of samples, taken from the first modality."""
        return len(self.Xs[0])

    def __getitem__(self, idx):
        """
        Return sample ``idx`` from every modality.

        Parameters
        ----------
        idx : int
            Index of the sample to fetch.

        Returns
        -------
        tuple
            One entry per modality, in the same order as ``Xs``. When
            ``transform`` is set, entry ``i`` is ``transform[i](Xs[i][idx])``.

        Raises
        ------
        IndexError
            If ``transform`` holds fewer callables than there are modalities.
        """
        if self.transform is not None:
            # The transform is selected by the modality position, not by the
            # sample index: each modality has its own callable.
            Xs = [
                self.transform[mod_idx](X[idx])
                for mod_idx, X in enumerate(self.Xs)
            ]
        else:
            Xs = [X[idx] for X in self.Xs]
        return tuple(Xs)