Source code for imml.statistics.pid

# License: BSD-3-Clause

from itertools import combinations
import numpy as np
from cvxpy import SCS
from scipy.special import rel_entr
from sklearn import preprocessing
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.preprocessing import LabelEncoder
import cvxpy as cp

from ..utils import check_Xs_y


def pid(Xs, y, n_clusters=20, n_components=.95, random_state=None, normalize: bool = False,
        return_index: bool = False):
    r"""
    Quantify the degree of redundancy, uniqueness, and synergy (PID statistics) relating input modalities with an
    output task. [#pidpaper]_ [#pidcode]_

    Each modality is (optionally) projected with PCA and then discretized with KMeans. For every unordered pair of
    modalities, the empirical joint distribution of (cluster of modality 1, cluster of modality 2, y) is built and
    the PID statistics are estimated from it.

    Parameters
    ----------
    Xs : list of array-like
        - Xs length: n_mods
        - Xs[i] shape: (n_samples, n_features_i)

        A list of different modalities. At least two modalities are required.
    y : array-like of shape (n_samples,)
        Target vector relative to Xs.
    n_clusters : int or list of int, default=20
        The number of clusters to generate. If an int, the same value is used for each modality. If a list, it
        should contain integers specifying the number of clusters for each modality.
    n_components : float, int, list, or None, default=0.95
        PCA dimensionality reduction setting per modality:

        - If float in (0, 1), it is treated as the fraction of variance to preserve (PCA(n_components=float)).
        - If int >= 1, it specifies the exact number of components.
        - If None, PCA (and the sample normalization below) is skipped for that modality.

        If a single value is provided, it is used for all modalities. If a list is provided, it should contain the
        value for each modality. Before PCA, each sample (row) is scaled to unit L2 norm with
        ``sklearn.preprocessing.normalize``.
    random_state : int, default=None
        Determines the randomness of PCA and KMeans. Use an int to make the randomness deterministic.
    normalize : bool, default=False
        If True, the four PID components of each pair are divided by their sum so that they add up to 1 (they are
        left unchanged when that sum is 0). If False, raw values are returned. The 'Information' entry always
        holds the raw total.
    return_index : bool, default=False
        If True, also return the indices of the pairs.

    Returns
    -------
    rus : dict or list of dicts
        Dictionary with the following keys:

        - 'Information': float. Sum of the four raw PID components (computed before any normalization).
        - 'Redundancy': float.
        - 'Uniqueness1': float. Unique information from the first modality in the pair.
        - 'Uniqueness2': float. Unique information from the second modality in the pair.
        - 'Synergy': float.

        If len(Xs) > 2, a list of such dicts is returned, one per unordered modality pair, in the order generated
        by ``itertools.combinations(range(len(Xs)), 2)``.
    indices : list of tuple of int, optional
        The indices of the modalities forming each pair. Only provided if return_index is True.

    Raises
    ------
    ValueError
        If fewer than two modalities are passed, if n_clusters or n_components do not match the number of
        modalities, if normalize or return_index are not booleans, or if y does not have one entry per sample.

    References
    ----------
    .. [#pidpaper] Liang, Paul Pu, et al. "Quantifying & modeling multimodal interactions: An information
                   decomposition framework." Advances in Neural Information Processing Systems 36 (2023):
                   27351-27393.
    .. [#pidcode] https://github.com/pliang279/PID/tree/main

    See Also
    --------
    :class:`~imml.visualize.plot_pid`

    Example
    --------
    >>> import numpy as np
    >>> import pandas as pd
    >>> from imml.statistics import pid
    >>> Xs = [pd.DataFrame(np.random.default_rng(42).random((20, 10))) for _ in range(3)]
    >>> y = np.random.default_rng(42).choice(2, size=len(Xs[0]))
    >>> pid(Xs=Xs, y=y, n_clusters=20, n_components=0.95, random_state=42)
    """
    Xs = check_Xs_y(Xs=Xs)
    if len(Xs) < 2:
        raise ValueError(f"Invalid Xs. At least two modalities are required. {len(Xs)} modalities were passed.")

    if not isinstance(n_clusters, list):
        n_clusters = [n_clusters]
    if not isinstance(n_components, list):
        n_components = [n_components]
    if len(n_clusters) == 1:
        n_clusters = n_clusters * len(Xs)
    if len(n_components) == 1:
        n_components = n_components * len(Xs)
    if len(n_clusters) != len(Xs):
        raise ValueError(f"Invalid n_clusters. n_clusters must have the same length as Xs."
                         f" {len(n_clusters)} clusters for {len(Xs)} modalities were passed.")
    if len(n_components) != len(Xs):
        raise ValueError(f"Invalid n_components. n_components must have the same length as Xs."
                         f" {len(n_components)} components for {len(Xs)} modalities were passed.")
    if not isinstance(normalize, bool):
        raise ValueError(f"Invalid normalize. normalize must be a boolean. {type(normalize)} was passed.")
    if not isinstance(return_index, bool):
        raise ValueError(f"Invalid return_index. return_index must be a boolean. {type(return_index)} was passed.")

    # Encode the target as consecutive integers so it can index the last axis of the joint distribution.
    y = LabelEncoder().fit_transform(y)
    if len(y) != len(Xs[0]):
        raise ValueError(f"Invalid y. y must have one entry per sample. {len(y)} labels for {len(Xs[0])} samples"
                         f" were passed.")

    # Cluster every modality once; each modality takes part in several pairs when len(Xs) > 2.
    labels = [_cluster_labels(X=Xs[idx], clusters=n_clusters[idx], components=n_components[idx],
                              random_state=random_state) for idx in range(len(Xs))]

    total_idxs = []
    rus = []
    for Xs_pair in combinations(range(len(Xs)), 2):
        first, second = Xs_pair
        joint_distribution = _joint_distribution(labels[first], labels[second], y)
        Q = _solve_Q_new(joint_distribution)
        rus_pair = {
            'Redundancy': CoI(Q),
            # Conditioning on the other modality under Q gives the unique information of this one.
            'Uniqueness1': _UI(Q, cond_id=1),
            'Uniqueness2': _UI(Q, cond_id=0),
            'Synergy': _CI(joint_distribution, Q),
        }
        total = sum(rus_pair.values())
        if normalize and total != 0:
            rus_pair = {key: value / total for key, value in rus_pair.items()}
        rus.append({"Information": total, **rus_pair})
        total_idxs.append(Xs_pair)

    if len(rus) == 1:
        rus = rus[0]
    if return_index:
        return rus, total_idxs
    return rus


def _cluster_labels(X, clusters, components, random_state):
    """Discretize one modality into consecutive integer cluster codes starting at 0.

    If ``components`` is not None, each sample is first scaled to unit L2 norm and projected with PCA. The KMeans
    labels are re-coded with ``np.unique`` so they stay valid indices into the joint distribution even if some
    cluster label ends up unused.
    """
    if components is not None:
        X = preprocessing.normalize(X)
        X = PCA(n_components=components, random_state=random_state).fit_transform(X)
    kmeans = KMeans(n_clusters=clusters, random_state=random_state).fit(X)
    return np.unique(kmeans.labels_, return_inverse=True)[1].reshape(-1)


def _joint_distribution(labels1, labels2, y):
    """Build the empirical joint distribution p(x1, x2, y) from two dense label vectors and an encoded target."""
    joint = np.zeros((len(np.unique(labels1)), len(np.unique(labels2)), len(np.unique(y))))
    # np.add.at accumulates correctly for repeated (x1, x2, y) triples, unlike fancy-index assignment.
    np.add.at(joint, (labels1, labels2, y), 1)
    return joint / joint.sum()
def _CI(P, Q): assert P.shape == Q.shape P_ = P.transpose([2, 0, 1]).reshape((P.shape[2], P.shape[0]*P.shape[1])) Q_ = Q.transpose([2, 0, 1]).reshape((Q.shape[2], Q.shape[0]*Q.shape[1])) return _MI(P_) - _MI(Q_) def _MI(P: np.ndarray): ''' P has 2 dimensions ''' margin_1 = P.sum(axis=1) margin_2 = P.sum(axis=0) outer = np.outer(margin_1, margin_2) return np.sum(rel_entr(P, outer)) # return np.sum(P * np.log(P/outer)) def CoI(P:np.ndarray): ''' P has 3 dimensions, in order X1, X2, Y ''' # MI(Y; X1) A = P.sum(axis=1) # MI(Y; X2) B = P.sum(axis=0) # MI(Y; (X1, X2)) C = P.transpose([2, 0, 1]).reshape((P.shape[2], P.shape[0]*P.shape[1])) return _MI(A) + _MI(B) - _MI(C) def _UI(P, cond_id=0): ''' P has 3 dimensions, in order X1, X2, Y We condition on X1 if cond_id = 0, if 1, then X2. ''' P_ = np.copy(P) sum = 0. if cond_id == 0: J= P.sum(axis=(1,2)) # marginal of x1 for i in range(P.shape[0]): sum += _MI(P[i, :, :] / P[i, :, :].sum()) * J[i] elif cond_id == 1: J= P.sum(axis=(0,2)) # marginal of x1 for i in range(P.shape[1]): sum += _MI(P[:,i, :] / P[:,i, :].sum()) * J[i] return sum def _solve_Q_new(P: np.ndarray): ''' Compute optimal Q given 3d array P with dimensions coressponding to x1, x2, and y respectively ''' Py = P.sum(axis=0).sum(axis=0) Px1 = P.sum(axis=1).sum(axis=1) Px2 = P.sum(axis=0).sum(axis=1) Px2y = P.sum(axis=0) Px1y = P.sum(axis=1) Px1y_given_x2 = P / P.sum(axis=(0, 2), keepdims=True) Q = [cp.Variable((P.shape[0], P.shape[1]), nonneg=True) for i in range(P.shape[2])] Q_x1x2 = [cp.Variable((P.shape[0], P.shape[1]), nonneg=True) for i in range(P.shape[2])] # Constraints that conditional distributions sum to 1 sum_to_one_Q = cp.sum([cp.sum(q) for q in Q]) == 1 # Brute force constraints # # [A]: p(x1, y) == q(x1, y) # [B]: p(x2, y) == q(x2, y) # Adding [A] constraints A_cstrs = [] for x1 in range(P.shape[0]): for y in range(P.shape[2]): vars = [] for x2 in range(P.shape[1]): vars.append(Q[y][x1, x2]) A_cstrs.append(cp.sum(vars) == Px1y[x1, y]) # Adding [B] 
constraints B_cstrs = [] for x2 in range(P.shape[1]): for y in range(P.shape[2]): vars = [] for x1 in range(P.shape[0]): vars.append(Q[y][x1, x2]) B_cstrs.append(cp.sum(vars) == Px2y[x2, y]) # KL divergence Q_pdt_dist_cstrs = [cp.sum(Q) / P.shape[2] == Q_x1x2[i] for i in range(P.shape[2])] # objective obj = cp.sum([cp.sum(cp.rel_entr(Q[i], Q_x1x2[i])) for i in range(P.shape[2])]) all_constrs = [sum_to_one_Q] + A_cstrs + B_cstrs + Q_pdt_dist_cstrs prob = cp.Problem(cp.Minimize(obj), all_constrs) try: prob.solve(verbose=False, max_iters=10000) except: prob.solve(solver=SCS, verbose=False, max_iters=10000) return np.stack([q.value for q in Q], axis=2)