# License: BSD-3-Clause
from itertools import combinations
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from ..utils import check_Xs_y
[docs]
def plot_combinations(Xs: list, mod_names: list = None, figsize: tuple = None, max_combs: int = 10):
    r"""
    Plot the number of samples per modality combination.

    This function summarizes how many samples are present in each intersection of modalities
    (i.e., samples that are available simultaneously across two or more modalities). The
    resulting figure is similar to an UpSet plot, but it displays the exact counts of
    intersections.

    A sample counts as available in a modality when its row is not entirely missing in
    that modality. Samples are matched across modalities by their row index.

    Parameters
    ----------
    Xs : list of array-like objects
        - Xs length: n_mods
        - Xs[i] shape: (n_samples, n_features_i)

        A list of different modalities.
    mod_names : list of str, default=None
        Names of each modality (length must match ``len(Xs)``). If ``None``, modality names
        default to indices: ``["0", "1", ...]``.
    figsize : tuple, default=None
        Figure size in inches passed to ``matplotlib.pyplot.subplots``.
    max_combs : int, default=10
        Maximum number of intersections to display. The largest intersections are kept, and
        the intersection of all modalities is always included. If fewer intersections are
        available, all will be shown.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The created matplotlib Figure.
    axes : numpy.ndarray of matplotlib.axes.Axes
        2 x 2 array of Axes:

        - ``axes[0, 0]``: empty (axis turned off).
        - ``axes[0, 1]``: bar chart with the size of each displayed intersection.
        - ``axes[1, 0]``: horizontal bar chart with the number of available samples per modality.
        - ``axes[1, 1]``: dot matrix marking which modalities form each intersection.

    Raises
    ------
    ValueError
        If ``figsize`` is not a tuple, if ``max_combs`` is not an int greater than or equal
        to 1, if fewer than two modalities are given, or if the length of ``mod_names`` does
        not match the number of modalities.

    See Also
    --------
    :class:`~imml.visualize.plot_missing_modality`
    :class:`~imml.visualize.plot_summary`

    Examples
    --------
    >>> import numpy as np
    >>> import pandas as pd
    >>> from imml.visualize import plot_combinations
    >>> from imml.ampute import Amputer
    >>> Xs = [pd.DataFrame(np.random.default_rng(42).random((20, 10))) for _ in range(3)]
    >>> Xs = Amputer(p=0.3, random_state=42).fit_transform(Xs)
    >>> fig, axes = plot_combinations(Xs=Xs, mod_names=['RNA', 'Protein', 'Metabolite'], max_combs=8)
    """
    Xs = check_Xs_y(Xs=Xs, modalities=mod_names)
    if (figsize is not None) and (not isinstance(figsize, tuple)):
        raise ValueError(f"Invalid figsize. It must be a tuple. A {type(figsize)} was passed.")
    if not isinstance(max_combs, int):
        raise ValueError(f"Invalid max_combs. It must be an int. A {type(max_combs)} was passed.")
    if max_combs < 1:
        # A non-positive value would be interpreted as a negative slice bound below.
        raise ValueError(f"Invalid max_combs. It must be at least 1. {max_combs} was passed.")
    n_mods = len(Xs)
    if n_mods < 2:
        raise ValueError(f"Invalid Xs. At least two modalities are required. {n_mods} was passed.")
    if mod_names is None:
        mod_names = [str(i) for i in range(n_mods)]
    elif len(mod_names) != n_mods:
        raise ValueError(
            f"Invalid mod_names. Its length ({len(mod_names)}) must match the number of modalities ({n_mods}).")

    # Convert each modality on its own so mixed lists (DataFrames and arrays) are supported.
    Xs = [X if isinstance(X, pd.DataFrame) else pd.DataFrame(X) for X in Xs]
    # Keep only the samples that are observed (rows that are not entirely missing).
    Xs = [X.loc[~X.isna().all(axis=1)] for X in Xs]

    # Size of the index intersection for every combination of two or more modalities.
    intersection_sizes = {}
    for size in range(2, n_mods + 1):
        for combo in combinations(range(n_mods), size):
            shared_idx = Xs[combo[0]].index
            for mod_idx in combo[1:]:
                shared_idx = shared_idx.intersection(Xs[mod_idx].index)
            intersection_sizes[tuple(mod_names[i] for i in combo)] = len(shared_idx)

    fig, axes = plt.subplots(2, 2, constrained_layout=True, figsize=figsize, gridspec_kw={'width_ratios': [1, 2]})

    # Top-left panel is intentionally left blank.
    ax = axes[0, 0]
    ax.axis("off")

    # One row per combination; shorter combinations are padded with missing values.
    combs = pd.DataFrame(list(intersection_sizes))
    combs["size"] = list(intersection_sizes.values())
    combs = combs.sort_values(by="size", ascending=False)

    # Keep the largest intersections and always show the intersection of all modalities
    # (the only row without padding), without ever listing it twice.
    complete = combs.loc[~combs.isna().any(axis=1)]
    shown = combs.iloc[:max_combs]
    missing_complete = complete.loc[~complete.index.isin(shown.index)]
    if not missing_complete.empty:
        head = combs.iloc[:max_combs - 1]
        shown = pd.concat([head, missing_complete]) if len(head) else missing_complete
    combs = shown

    # Top-right panel: intersection sizes.
    ax = axes[0, 1]
    ax = combs.plot(kind="bar", y="size", ax=ax, ylabel="Intersection size", color="black", legend=False)
    ax.get_xaxis().set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    for container in ax.containers:
        ax.bar_label(container)

    # Bottom-left panel: number of observed samples per modality.
    ax = axes[1, 0]
    mod_counts = pd.Series([len(X) for X in Xs], index=mod_names).sort_values(ascending=True)
    # pd.notna covers both None and NaN padding, whichever pandas uses for the missing cells.
    shown_names = {
        name
        for row in combs.drop(columns="size").values
        for name in row
        if pd.notna(name)
    }
    mod_counts = mod_counts.loc[[name for name in mod_counts.index if name in shown_names]]
    mod_names = mod_counts.index.tolist()
    # Vertical position of each modality, shared by the bar chart and the dot matrix.
    cat_to_y = {cat: i for i, cat in enumerate(mod_names)}
    ax = mod_counts.plot(kind="barh", ax=ax, xlabel="Set size", color="black")
    ax.invert_xaxis()
    ax.get_yaxis().set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    for container in ax.containers:
        ax.bar_label(container, padding=2)

    # Bottom-right panel: dot matrix of the modalities forming each intersection.
    ax = axes[1, 1]
    combs = combs.drop(columns="size").reset_index(drop=True)
    combs = combs.iloc[::-1]
    # The "index" column holds each combination's bar position in the top-right panel.
    positions = combs.reset_index()
    with plt.style.context('seaborn-v0_8-darkgrid'):
        for col in combs.columns:
            # Rows where this slot holds a modality name (padding is dropped).
            current_col = positions.loc[positions[col].isin(mod_names)].copy()
            current_col[col] = current_col[col].map(cat_to_y)
            ax = current_col.plot(kind="scatter", ax=ax, x="index", y=col, s=200, ylabel="", c="black")
    ax.get_xaxis().set_visible(False)
    # Align the dots with the bars of the neighbouring panels.
    ax.set_xlim(axes[0, 1].get_xlim())
    ax.set_ylim(axes[1, 0].get_ylim())
    ax.set_yticks(list(cat_to_y.values()))
    ax.set_yticklabels(mod_names)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.spines['left'].set_visible(False)
    return fig, axes