# License: BSD-3-Clause
import numpy as np
from sklearn.model_selection import train_test_split
from ..model_selection._multi_modal_dataset import _MultiModalDataset
from ..utils import check_Xs_y
[docs]
def train_test_mm_split(Xs, y=None, **kwargs):
    """
    Split a multi-modal dataset and, optionally, its labels into train and test sets.

    Similar to sklearn's `train_test_split
    <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html>`_, but works with
    a list of arrays/data (Xs) and a single array of labels (y). A single set of train/test indices is drawn and
    applied to every X in Xs (and to y), so all modalities get the same split.

    Parameters
    ----------
    Xs : list of array-likes
        Multi-modal data where each element is a modality.
    y : array-like, optional (default=None)
        Labels. Objects exposing ``.iloc`` (e.g. pandas Series/DataFrame) are indexed positionally through
        ``.iloc``; lists and tuples are converted to a numpy array before indexing; any other object is indexed
        directly with the integer index arrays.
    **kwargs : dict
        Additional keyword arguments to pass to sklearn's train_test_split (e.g. ``train_size``, ``test_size``,
        ``random_state``).

    Returns
    -------
    tuple
        ``(Xs_train, Xs_test)`` when ``y`` is None, otherwise ``(Xs_train, Xs_test, y_train, y_test)``.
        ``Xs_train`` and ``Xs_test`` are lists with one entry per modality.

    Examples
    --------
    >>> import numpy as np
    >>> from imml.model_selection import train_test_mm_split
    >>> Xs = [np.random.rand(100, 10), np.random.rand(100, 20)]
    >>> y = np.random.randint(0, 2, 100)
    >>> Xs_train, Xs_test, y_train, y_test = train_test_mm_split(Xs, y, train_size=0.7, random_state=42)
    """
    # Input checking is delegated to the project's helper; its return value is not used here.
    check_Xs_y(Xs=Xs, y=y)

    Xs = _MultiModalDataset(Xs)

    # Split an index array instead of the data itself so that the very same
    # train/test indices can be reused for every modality and for the labels.
    idxs = np.arange(len(Xs))
    tr, te = train_test_split(idxs, **kwargs)

    output = (Xs[tr].to_list(), Xs[te].to_list())
    if y is None:
        return output

    if hasattr(y, "iloc"):
        # pandas-like containers: use positional indexing, since plain
        # ``y[tr]`` would be interpreted as label-based selection.
        return (*output, y.iloc[tr], y.iloc[te])

    if isinstance(y, (list, tuple)):
        # Built-in sequences cannot be indexed with an integer index array.
        y = np.asarray(y)
    return (*output, y[tr], y[te])