# Source code for imml.visualize.plot_summary

# License: BSD-3-Clause

import pandas as pd

from ..explore import get_summary


def plot_summary(Xs: list = None, summary: pd.DataFrame = None, mod_names: list = None, figsize: tuple = None,
                 title: str = None, xlabel: str = None, ylabel: str = "Count"):
    r"""
    Plot a bar chart summarizing completeness across modalities in a multi-modal dataset.

    Parameters
    ----------
    Xs : list of array-like objects, default=None
        - Xs length: n_mods
        - Xs[i] shape: (n_samples, n_features_i)

        A list of different modalities. Only used when ``summary`` is not provided.
    summary : pd.DataFrame, default=None
        A summary dataframe as returned by ``imml.explore.get_summary``. If provided, it will be plotted
        directly. If None, the summary will be computed from ``Xs``. The dataframe passed in is never
        modified; a copy is used for plotting.
    mod_names : list, default=None
        Names of each modality to use when computing the summary from ``Xs``. If ``None``, it will default
        to the modality index.
    figsize : tuple, default=None
        Figure size in inches passed to ``pd.DataFrame.plot``.
    title : str, default=None
        Title of the plot, passed to ``pd.DataFrame.plot``. If None, no title is set.
    xlabel : str, default=None
        Label for the x-axis, passed to ``pd.DataFrame.plot``. If None, pandas' default behaviour applies.
    ylabel : str, default="Count"
        Label for the y-axis.

    Returns
    -------
    matplotlib.axes.Axes
        The matplotlib Axes containing the bar plot.

    Raises
    ------
    ValueError
        If ``figsize`` is not a tuple, if neither ``Xs`` nor ``summary`` is provided, or if ``summary``
        is not a ``pd.DataFrame``.

    See Also
    --------
    :class:`~imml.explore.get_summary`
    :class:`~imml.visualize.plot_missing_modality`
    :class:`~imml.visualize.plot_combinations`

    Example
    --------
    >>> import numpy as np
    >>> import pandas as pd
    >>> from imml.visualize import plot_summary
    >>> from imml.ampute import Amputer
    >>> Xs = [pd.DataFrame(np.random.default_rng(42).random((20, 10))) for i in range(3)]
    >>> Xs = Amputer(p=0.3, random_state=42).fit_transform(Xs)
    >>> plot_summary(Xs = Xs)
    """
    if (figsize is not None) and (not isinstance(figsize, tuple)):
        raise ValueError(f"Invalid figsize. It must be a tuple. A {type(figsize)} was passed.")
    if summary is None:
        if Xs is None:
            # Without either input there is nothing to summarize; fail with a clear message here
            # instead of passing None down to get_summary.
            raise ValueError("Either Xs or summary must be provided.")
        summary = get_summary(Xs=Xs, mod_names=mod_names, compute_pct=False, return_df=True)
    if not isinstance(summary, pd.DataFrame):
        raise ValueError(f"Invalid summary. It should be a pd.DataFrame. A {type(summary)} was passed.")

    # Work on a copy so a caller-supplied ``summary`` keeps its original index.
    summary = summary.copy()
    # Drop the " samples" suffix from string row labels to keep x-axis tick labels short;
    # non-string labels are left as they are instead of crashing the ``.str`` accessor.
    summary.index = summary.index.map(
        lambda label: label.replace(" samples", "") if isinstance(label, str) else label
    )
    # Plot only the count columns: any column whose name starts with '%' is excluded.
    count_columns = [col for col in summary.columns if not str(col).startswith("%")]
    ax = summary[count_columns].plot(
        kind="bar", xlabel=xlabel, ylabel=ylabel, rot=0, title=title, figsize=figsize
    )
    return ax