Source code for imml.visualize.plot_missing_modality

# License: BSD-3-Clause

import matplotlib
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

from ..impute import get_observed_mod_indicator
from ..utils import check_Xs_y


def plot_missing_modality(Xs, ax: matplotlib.axes.Axes = None, figsize: tuple = None, sort: bool = True):
    r"""
    Plot which modalities are observed and which are missing for each sample.

    Each row of the plot is a sample and each column is a modality. Missing
    modalities are drawn as white cells and observed modalities as gray cells;
    every cell is outlined in black.

    Parameters
    ----------
    Xs : list of array-like objects
        - Xs length: n_mods
        - Xs[i] shape: (n_samples, n_features_i)

        A list of different modalities.
    ax : matplotlib.axes.Axes, default=None
        Axes where the figure is drawn. If None, a new figure and axes are
        created.
    figsize : tuple, default=None
        Figure size (tuple) in inches. Only used when ``ax`` is None.
    sort : bool, default=True
        If True, samples are sorted according to their pattern of available
        modalities, so that samples sharing the same pattern are grouped.

    Returns
    -------
    fig : `matplotlib.figure.Figure` or None
        Figure object, or None when ``ax`` was provided by the caller.
    ax : `matplotlib.axes.Axes`
        Axes object.

    Raises
    ------
    ValueError
        If ``ax`` is not a ``matplotlib.axes.Axes``, ``figsize`` is not a
        tuple, or ``sort`` is not a bool.

    See Also
    --------
    :class:`~imml.visualize.plot_summary`
    :class:`~imml.visualize.plot_combinations`

    Examples
    --------
    >>> import numpy as np
    >>> import pandas as pd
    >>> from imml.ampute import Amputer
    >>> from imml.visualize import plot_missing_modality
    >>> Xs = [pd.DataFrame(np.random.default_rng(42).random((20, 10))) for i in range(3)]
    >>> transformer = Amputer(p=0.2, random_state=42)
    >>> Xs = transformer.fit_transform(Xs)
    >>> fig, ax = plot_missing_modality(Xs=Xs)
    """
    Xs = check_Xs_y(Xs=Xs)
    if ax is not None and not isinstance(ax, matplotlib.axes.Axes):
        raise ValueError(f"Invalid ax. It must be a matplotlib.axes.Axes. A {type(ax)} was passed.")
    if figsize is not None and not isinstance(figsize, tuple):
        raise ValueError(f"Invalid figsize. It must be a tuple. A {type(figsize)} was passed.")
    if not isinstance(sort, bool):
        raise ValueError(f"Invalid sort. It must be a bool. A {type(sort)} was passed.")

    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize)
    else:
        # When the caller supplies the axes, no figure is returned (fig is None).
        fig = None

    # Observation indicator: one row per sample, one column per modality.
    observed_mod_indicator = pd.DataFrame(get_observed_mod_indicator(Xs))
    if sort:
        # Ascending lexicographic sort over all modality columns groups samples
        # that share the same pattern of observed/missing modalities.
        observed_mod_indicator = observed_mod_indicator.sort_values(list(observed_mod_indicator.columns))

    n_mods = observed_mod_indicator.shape[1]
    # With vmin=0 and vmax=2, the value 1 falls at the middle of the "binary"
    # colormap, so observed cells render gray while 0-valued cells render white.
    ax.pcolor(observed_mod_indicator, cmap="binary", edgecolors="black", vmin=0., vmax=2.)
    # Centre one tick on each modality column and number modalities from 1.
    ax.set_xticks(np.arange(0.5, n_mods, 1), list(range(1, n_mods + 1)))
    ax.set_xlabel("Modality")
    ax.set_ylabel("Samples")
    return fig, ax