Statistics

PID

class imml.statistics.pid(Xs, y, n_clusters=20, n_components=0.95, random_state=None, normalize: bool = False, return_index: bool = False)

Quantify the redundancy, uniqueness, and synergy between pairs of input modalities with respect to an output task, using partial information decomposition (PID) statistics. [1] [2]
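For one pair of modalities X1, X2 and target Y, PID splits the mutual information into four non-negative parts. The standard PID consistency equations are shown below. Mapping R, U1, U2, S to the 'redundancy', 'unique1', 'unique2', and 'synergy' keys follows the Returns section; reading 'total' as I(X1, X2; Y) is an assumption, not something this page states.

```latex
\begin{aligned}
I(X_1, X_2; Y) &= R + U_1 + U_2 + S \\
I(X_1; Y) &= R + U_1 \\
I(X_2; Y) &= R + U_2
\end{aligned}
```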

Parameters:
  • Xs (list of array-like) --

    • Xs length: n_mods

    • Xs[i] shape: (n_samples, n_features_i)

    A list of different modalities.

  • y (array-like of shape (n_samples,)) -- Target vector relative to Xs.

  • n_clusters (int or list of int, default=20) -- The number of clusters to generate for each modality. If an int, the same value is used for every modality; if a list, it must contain one integer per modality.

  • n_components (float, int, list, or None, default=0.95) -- PCA dimensionality-reduction setting per modality:

    • If a float in (0, 1], it is the fraction of variance to preserve (PCA(n_components=float)).

    • If an int >= 1, it is the exact number of components.

    • If None, PCA is skipped for that modality.

    A single value is used for all modalities; a list must contain one value per modality. Features are normalized before PCA.

  • random_state (int or None, default=None) -- Controls the randomness of the procedure. Pass an int for deterministic, reproducible results.

  • normalize (bool, default=False) -- If True, the resulting PID components for each pair are normalized to sum to 1. If False, raw values are returned.

  • return_index (bool, default=False) -- If True, also return the indices of the pairs.
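The per-modality rule for n_clusters and n_components (one value shared by all modalities, or a list with one value per modality), together with the shape requirements on Xs and y, can be sketched in plain Python. broadcast_per_modality and check_inputs are hypothetical helpers for illustration, not imml functions.

```python
from typing import List, Sequence, TypeVar, Union

T = TypeVar("T")


def broadcast_per_modality(value: Union[T, Sequence[T]], n_mods: int) -> List[T]:
    """Expand a single setting to one entry per modality, or validate a list.

    Hypothetical helper mirroring the documented rule for n_clusters and
    n_components; not part of imml's public API.
    """
    if isinstance(value, (list, tuple)):
        if len(value) != n_mods:
            raise ValueError(f"expected {n_mods} values, got {len(value)}")
        return list(value)
    # A single value (int, float, or None) is reused for every modality.
    return [value] * n_mods


def check_inputs(Xs, y) -> None:
    """Check that every modality has as many rows as y has entries."""
    n_samples = len(y)
    for i, X in enumerate(Xs):
        if len(X) != n_samples:
            raise ValueError(f"Xs[{i}] has {len(X)} rows, expected {n_samples}")
```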
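A minimal NumPy sketch of the per-modality preprocessing described for n_components: normalize the features, then keep either a fixed number of principal components or just enough to reach a variance fraction. Two assumptions apply: the unspecified normalization is taken to be z-scoring, and PCA is computed via SVD rather than scikit-learn's PCA, whose selection rule for float values may differ slightly at the boundary. reduce_modality is a hypothetical name.

```python
import numpy as np


def reduce_modality(X, n_components=0.95):
    """Z-score features, then project onto principal components.

    Sketch only: z-scoring stands in for the unspecified normalization,
    and PCA is done with NumPy's SVD instead of scikit-learn.
    """
    X = np.asarray(X, dtype=float)
    std = X.std(axis=0)
    std[std == 0] = 1.0  # avoid dividing constant features by zero
    Z = (X - X.mean(axis=0)) / std
    if n_components is None:
        return Z  # PCA skipped for this modality
    # Z is centered, so its SVD gives the principal axes (rows of Vt);
    # S**2 is proportional to the variance along each axis.
    _, S, Vt = np.linalg.svd(Z, full_matrices=False)
    if isinstance(n_components, float):
        if not 0 < n_components <= 1:
            raise ValueError("float n_components must be in (0, 1]")
        ratio = np.cumsum(S**2) / np.sum(S**2)
        # Smallest k whose cumulative explained variance reaches the fraction.
        k = int(np.searchsorted(ratio, n_components) + 1)
    else:
        k = int(n_components)
    k = min(k, Vt.shape[0])  # guard against rounding and oversized requests
    return Z @ Vt[:k].T
```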
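With normalize=True, each pair's components are rescaled to sum to 1. The sketch below assumes the four rescaled components are 'redundancy', 'unique1', 'unique2', and 'synergy', and that 'total' then becomes their new sum (1.0); how imml actually treats 'total' is not stated here. normalize_pid is a hypothetical helper.

```python
from typing import Dict

# Assumed set of components that normalize=True rescales.
PID_KEYS = ("redundancy", "unique1", "unique2", "synergy")


def normalize_pid(rus: Dict[str, float]) -> Dict[str, float]:
    """Rescale the four PID components of one pair so they sum to 1.

    Hypothetical sketch of normalize=True; 'total' is recomputed as the
    sum of the rescaled components, which is an assumption.
    """
    total = sum(rus[k] for k in PID_KEYS)
    if total == 0:
        return dict(rus)  # nothing to rescale
    out = {k: rus[k] / total for k in PID_KEYS}
    out["total"] = sum(out.values())
    return out
```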

Returns:

  • rus (dict or list of dict) -- PID statistics with the following keys:

    • 'redundancy': float

    • 'unique1': float (unique information from the first modality in the pair)

    • 'unique2': float (unique information from the second modality in the pair)

    • 'synergy': float

    • 'total': float

    If len(Xs) > 2, a list of such dicts is returned, one per unordered modality pair, in the order generated by itertools.combinations(Xs, 2). With two modalities, a single dict is returned.

  • indices (list, optional) -- The indices of the pairs. Only provided if return_index is True.
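Because pairs follow itertools.combinations order, the position of each dict in the returned list can be mapped back to modality indices. The sketch assumes indices holds (i, j) tuples like these; modality_pairs is a hypothetical helper.

```python
from itertools import combinations


def modality_pairs(n_mods: int):
    """Unordered modality-index pairs, in itertools.combinations order."""
    return list(combinations(range(n_mods), 2))


# With three modalities, the returned list of dicts lines up with:
pairs = modality_pairs(3)  # [(0, 1), (0, 2), (1, 2)]
```

For n modalities this yields n * (n - 1) / 2 pairs, so the result list grows quadratically with the number of modalities.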
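Assuming each modality is first discretized into cluster labels (our reading of n_clusters), the quantity the 'total' key most likely reports, I(X1, X2; Y), can be estimated by a plug-in count over the joint (cluster1, cluster2, y) table. This sketches that estimate only; splitting it into redundancy, unique, and synergy parts typically requires an additional optimization step that is not shown. The function names are hypothetical.

```python
import numpy as np


def mutual_information(a, b) -> float:
    """Plug-in estimate of I(A; B) in bits from two discrete label arrays."""
    _, a_idx = np.unique(np.asarray(a), return_inverse=True)
    _, b_idx = np.unique(np.asarray(b), return_inverse=True)
    # Count co-occurrences, then normalize to a joint probability table.
    joint = np.zeros((a_idx.max() + 1, b_idx.max() + 1))
    np.add.at(joint, (a_idx, b_idx), 1)
    joint /= joint.sum()
    pa = joint.sum(axis=1, keepdims=True)  # marginal of A, shape (na, 1)
    pb = joint.sum(axis=0, keepdims=True)  # marginal of B, shape (1, nb)
    nz = joint > 0  # skip empty cells, where 0 * log(0) is taken as 0
    return float(np.sum(joint[nz] * np.log2(joint[nz] / (pa @ pb)[nz])))


def total_information(c1, c2, y) -> float:
    """I(C1, C2; Y), treating each (c1, c2) cluster pair as one symbol.

    Assumes cluster labels are non-negative integers.
    """
    pair_symbols = np.asarray(c1) * (np.max(c2) + 1) + np.asarray(c2)
    return mutual_information(pair_symbols, y)
```

On an XOR-style target (c1 = [0, 0, 1, 1], c2 = [0, 1, 0, 1], y = [0, 1, 1, 0]), each label set alone carries 0 bits about y while the pair carries 1 bit; under the PID identities this corresponds to pure synergy.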

References

See also

plot_pid

Example

>>> import numpy as np
>>> import pandas as pd
>>> from imml.statistics import pid
>>> rng = np.random.default_rng(42)  # one generator, so the three modalities differ
>>> Xs = [pd.DataFrame(rng.random((20, 10))) for _ in range(3)]
>>> y = rng.choice(2, size=len(Xs[0]))
>>> rus = pid(Xs=Xs, y=y, n_clusters=20, n_components=0.95, random_state=42)