from itertools import combinations
import numpy as np
from cvxpy import SCS
from scipy.special import rel_entr
from sklearn import preprocessing
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.preprocessing import LabelEncoder
import cvxpy as cp
from ..utils import check_Xs
[docs]
def pid(Xs, y, n_clusters=20, n_components=.95, random_state=None, normalize: bool = False,
        return_index: bool = False):
    r"""
    Quantify the degree of redundancy, uniqueness, and synergy (PID statistics) relating input modalities with an
    output task. [#pidpaper]_ [#pidcode]_

    Each modality is (optionally) projected with PCA and discretized with KMeans. For every unordered pair of
    modalities, the empirical joint distribution p(x1, x2, y) of cluster labels and target is built. The PID terms
    are then computed from it and from the distribution q returned by a convex program that matches p's (x1, y) and
    (x2, y) marginals. All quantities are in nats (natural logarithm).

    Parameters
    ----------
    Xs : list of array-like
        - Xs length: n_mods
        - Xs[i] shape: (n_samples, n_features_i)

        A list of different modalities.
    y : array-like of shape (n_samples,)
        Target vector relative to Xs. Arbitrary labels are accepted; they are label-encoded internally.
    n_clusters : int or list of int, default=20
        The number of clusters to generate. If an int, the same value is used for each modality. If a list (or
        tuple), it should contain integers specifying the number of clusters for each modality.
    n_components : float, int, list, or None, default=0.95
        PCA dimensionality reduction setting per modality:

        - If float in (0, 1], it is treated as the fraction of variance to preserve (PCA(n_components=float)).
        - If int >= 1, it specifies the exact number of components.
        - If None, both sample normalization and PCA are skipped for that modality.

        If a single value is provided, it is used for all modalities. If a list (or tuple) is provided, it should
        contain the value for each modality. Before PCA, each sample (row) is scaled to unit L2 norm.
    random_state : int or None, default=None
        Passed to PCA and KMeans. Use an int to make the results deterministic.
    normalize : bool, default=False
        If True, the PID components of each pair are divided by their sum so that they add up to 1. If False, raw
        values are returned.
    return_index : bool, default=False
        If True, also return the modality indices of each pair.

    Returns
    -------
    rus : dict or list of dicts with the following keys:
        - 'Redundancy': float
        - 'Uniqueness1': float (unique information from the first modality in the pair)
        - 'Uniqueness2': float (unique information from the second modality in the pair)
        - 'Synergy': float

        With exactly two modalities a single dict is returned. If len(Xs) > 2, a list of such dicts is returned,
        one per unordered modality pair, in the order generated by itertools.combinations.
    indices : list of tuple of int, optional
        The (i, j) modality indices of each pair. Only provided if return_index is True.

    Raises
    ------
    ValueError
        If n_clusters or n_components do not match the number of modalities, if fewer than two modalities are
        given, or if normalize / return_index are not booleans.

    References
    ----------
    .. [#pidpaper] Liang, Paul Pu, et al. "Quantifying & modeling multimodal interactions: An information
       decomposition framework." Advances in Neural Information Processing Systems 36 (2023):
       27351-27393.
    .. [#pidcode] https://github.com/pliang279/PID/tree/main

    Examples
    --------
    >>> import numpy as np
    >>> import pandas as pd
    >>> from imml.statistics import pid
    >>> Xs = [pd.DataFrame(np.random.default_rng(42).random((20, 10))) for _ in range(3)]
    >>> y = np.random.default_rng(42).choice(2, size=len(Xs[0]))
    >>> rus = pid(Xs=Xs, y=y, n_clusters=20, n_components=0.95, random_state=42)
    >>> len(rus)
    3
    """
    Xs = check_Xs(Xs=Xs)
    n_mods = len(Xs)
    n_clusters = list(n_clusters) if isinstance(n_clusters, (list, tuple)) else [n_clusters]
    n_components = list(n_components) if isinstance(n_components, (list, tuple)) else [n_components]
    if len(n_clusters) == 1:
        n_clusters = n_clusters * n_mods
    if len(n_components) == 1:
        n_components = n_components * n_mods
    if len(n_clusters) != n_mods:
        raise ValueError(f"Invalid n_clusters. n_clusters must have the same length as Xs."
                         f" {len(n_clusters)} clusters for {n_mods} modalities were passed.")
    if len(n_components) != n_mods:
        raise ValueError(f"Invalid n_components. n_components must have the same length as Xs."
                         f" {len(n_components)} components for {n_mods} modalities were passed.")
    if n_mods < 2:
        raise ValueError(f"Invalid Xs. At least two modalities are required. Xs with {n_mods} modalities were passed.")
    if not isinstance(normalize, bool):
        raise ValueError(f"Invalid normalize. normalize must be a boolean. {type(normalize)} was passed.")
    if not isinstance(return_index, bool):
        raise ValueError(f"Invalid return_index. return_index must be a boolean. {type(return_index)} was passed.")

    # LabelEncoder maps the targets onto 0..n_classes-1 so they can index the joint distribution directly.
    y = LabelEncoder().fit_transform(y)
    n_classes = len(np.unique(y))

    # Cluster every modality once and reuse the labels for all the pairs it takes part in.
    labels = [_cluster_labels(X, clusters, components, random_state)
              for X, clusters, components in zip(Xs, n_clusters, n_components)]

    total_idxs = []
    rus = []
    for idx1, idx2 in combinations(range(n_mods), 2):
        labels1, labels2 = labels[idx1], labels[idx2]
        # Empirical joint distribution p(x1, x2, y) of the two cluster assignments and the target.
        joint_distribution = np.zeros((labels1.max() + 1, labels2.max() + 1, n_classes))
        np.add.at(joint_distribution, (labels1, labels2, y), 1)
        joint_distribution /= np.sum(joint_distribution)
        Q = _solve_Q_new(joint_distribution)
        rus_pair = {
            'Redundancy': CoI(Q),
            # _UI(Q, cond_id=1) conditions on X2, i.e. the information only X1 carries about Y.
            'Uniqueness1': _UI(Q, cond_id=1),
            'Uniqueness2': _UI(Q, cond_id=0),
            'Synergy': _CI(joint_distribution, Q),
        }
        if normalize:
            total = sum(rus_pair.values())
            rus_pair = {key: value / total for key, value in rus_pair.items()}
        rus.append(rus_pair)
        total_idxs.append((idx1, idx2))

    if len(rus) == 1:
        rus = rus[0]
    if return_index:
        return rus, total_idxs
    return rus


def _cluster_labels(X, n_clusters, n_components, random_state):
    """Discretize one modality: optional row normalization + PCA, then KMeans; return contiguous integer labels."""
    if n_components is not None:
        X = preprocessing.normalize(X)
        X = PCA(n_components=n_components, random_state=random_state).fit_transform(X)
    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state).fit(X)
    # KMeans can end up with fewer distinct clusters than requested (e.g. duplicate points), leaving gaps in
    # labels_. Remapping to 0..k-1 keeps every label a valid, non-empty index into the joint distribution.
    _, labels = np.unique(kmeans.labels_, return_inverse=True)
    return labels
def _CI(P, Q):
    """
    Synergy term: I_P(Y; X1, X2) - I_Q(Y; X1, X2).

    Both P and Q are 3-D arrays with axes ordered (x1, x2, y) and must share the same shape. Each one is
    flattened to a 2-D table of shape (n_y, n_x1 * n_x2) so that the pair (X1, X2) is treated as a single variable.
    In ``pid``, P is the empirical joint distribution and Q is the output of ``_solve_Q_new``.
    """
    assert P.shape == Q.shape
    n_pairs = P.shape[0] * P.shape[1]
    n_y = P.shape[2]
    mi_under_p = _MI(np.moveaxis(P, 2, 0).reshape((n_y, n_pairs)))
    mi_under_q = _MI(np.moveaxis(Q, 2, 0).reshape((n_y, n_pairs)))
    return mi_under_p - mi_under_q
def _MI(P: np.ndarray):
    """
    Mutual information (in nats) between the row and column variables of a 2-D joint distribution P.

    Computed as KL(P || p_row x p_col) = sum_ij P_ij * log(P_ij / (p_row_i * p_col_j)).
    ``scipy.special.rel_entr`` applies the convention that entries with P_ij = 0 contribute 0.
    Callers in this module pass arrays that have already been normalized to sum to 1.
    """
    row_marginal = P.sum(axis=1)
    col_marginal = P.sum(axis=0)
    product_of_marginals = row_marginal[:, None] * col_marginal[None, :]
    return np.sum(rel_entr(P, product_of_marginals))
def CoI(P: np.ndarray):
    """
    Co-information I(Y; X1) + I(Y; X2) - I(Y; X1, X2) of a 3-D distribution P with axes ordered (x1, x2, y).

    ``pid`` evaluates this on the optimized distribution Q to obtain the redundancy term.
    """
    joint_x1_y = P.sum(axis=1)  # p(x1, y)
    joint_x2_y = P.sum(axis=0)  # p(x2, y)
    # p(y, (x1, x2)): y on the rows, the flattened (x1, x2) pair on the columns.
    joint_y_pair = np.moveaxis(P, 2, 0).reshape((P.shape[2], P.shape[0] * P.shape[1]))
    return _MI(joint_x1_y) + _MI(joint_x2_y) - _MI(joint_y_pair)
def _UI(P, cond_id=0):
    """
    Unique-information term: a conditional mutual information of a 3-D distribution P with axes (x1, x2, y).

    - cond_id=0 conditions on X1 and returns I(X2; Y | X1) = sum_x1 p(x1) * I(X2; Y | X1 = x1).
    - cond_id=1 conditions on X2 and returns I(X1; Y | X2) = sum_x2 p(x2) * I(X1; Y | X2 = x2).

    Conditioning values with zero probability mass contribute nothing. This is the limit of their p(.) * I(.)
    term, and skipping them avoids the 0/0 that would otherwise turn the whole result into nan.

    Raises
    ------
    ValueError
        If cond_id is not 0 or 1.
    """
    if cond_id == 0:
        slices = P
    elif cond_id == 1:
        # Bring x2 to the front so that iterating yields the (x1, y) table for each value of x2.
        slices = np.moveaxis(P, 1, 0)
    else:
        raise ValueError(f"Invalid cond_id. cond_id must be 0 or 1. {cond_id} was passed.")
    weights = slices.sum(axis=(1, 2))  # marginal of the conditioning variable
    total = 0.
    for weight, cond_slice in zip(weights, slices):
        mass = cond_slice.sum()
        if mass <= 0:
            continue
        total += _MI(cond_slice / mass) * weight
    return total
def _solve_Q_new(P: np.ndarray):
    """
    Compute the optimal distribution Q for a 3-D distribution P whose axes correspond to x1, x2 and y.

    Q is constrained to sum to 1 and to match P's (x1, y) and (x2, y) marginals (which also fixes q(y)).
    The objective is sum_y KL(Q[:, :, y] || q(x1, x2) / n_y), which equals -H_Q(Y | X1, X2) + log(n_y).
    With q(y) fixed, minimizing it minimizes I_Q(Y; X1, X2).

    Returns
    -------
    np.ndarray of the same shape as P, with entries clipped to be non-negative.

    Raises
    ------
    RuntimeError
        If the solver returns no solution.
    """
    n_x1, n_x2, n_y = P.shape
    Px1y = P.sum(axis=1)  # p(x1, y), shape (n_x1, n_y)
    Px2y = P.sum(axis=0)  # p(x2, y), shape (n_x2, n_y)

    # Q[k] holds q(x1, x2, y=k); Q_x1x2[k] is an auxiliary copy of q(x1, x2) / n_y for the KL objective.
    Q = [cp.Variable((n_x1, n_x2), nonneg=True) for _ in range(n_y)]
    Q_x1x2 = [cp.Variable((n_x1, n_x2), nonneg=True) for _ in range(n_y)]

    # The joint distribution sums to 1.
    sum_to_one_Q = cp.sum([cp.sum(q) for q in Q]) == 1
    # [A]: q(x1, y) == p(x1, y)  (sum over x2, i.e. along the columns of Q[y])
    A_cstrs = [cp.sum(Q[k], axis=1) == Px1y[:, k] for k in range(n_y)]
    # [B]: q(x2, y) == p(x2, y)  (sum over x1, i.e. along the rows of Q[y])
    B_cstrs = [cp.sum(Q[k], axis=0) == Px2y[:, k] for k in range(n_y)]
    # cp.sum on a list adds the expressions: sum_y q(x1, x2, y) / n_y.
    scaled_q_x1x2 = cp.sum(Q) / n_y
    Q_pdt_dist_cstrs = [scaled_q_x1x2 == Q_x1x2[k] for k in range(n_y)]

    obj = cp.sum([cp.sum(cp.rel_entr(Q[k], Q_x1x2[k])) for k in range(n_y)])
    all_constrs = [sum_to_one_Q] + A_cstrs + B_cstrs + Q_pdt_dist_cstrs
    prob = cp.Problem(cp.Minimize(obj), all_constrs)
    try:
        prob.solve(verbose=False, max_iters=10000)
    except Exception:
        # The default solver may fail or reject the max_iters option; fall back to SCS.
        # KeyboardInterrupt/SystemExit are deliberately not caught.
        prob.solve(solver=SCS, verbose=False, max_iters=10000)

    values = [q.value for q in Q]
    if any(value is None for value in values):
        raise RuntimeError(f"Failed to solve the PID optimization problem (status: {prob.status}).")
    # Solvers satisfy nonneg=True only up to tolerance. Tiny negative entries would make
    # scipy.special.rel_entr return inf in the downstream information terms, so clip them to 0.
    return np.clip(np.stack(values, axis=2), 0, None)