# License: BSD-3-Clause
import numpy as np
import pandas as pd
import snf
from sklearn.cluster import spectral_clustering, SpectralClustering
from sklearn.manifold import spectral_embedding
from sklearn.utils import check_symmetric
from snf.compute import _find_dominate_set
from ._integrao._aux_integrao import data_indexing, dist2, _stable_normalized_pd, _scaling_normalized_pd, p_preprocess, \
_stable_normalized
from ..preprocessing import remove_missing_samples_by_mod
# The deep-learning stack is an optional dependency: keep this module importable without it and
# record whether it is available.
try:
    import torch
    import lightning as L
    from torch import nn, optim, autograd
    from torch_geometric.nn import GraphSAGE
    deepmodule_installed = True
except ImportError:
    deepmodule_installed = False
    # Message used for the ImportError raised by IntegrAO.__init__ when the optional stack is missing.
    deepmodule_error = "Module 'deep' needs to be installed. See https://imml.readthedocs.io/stable/main/installation.html#optional-dependencies"

# Fall back to `object` so the class statements below can still be evaluated when the optional
# dependencies are not installed.
LightningModuleBase = L.LightningModule if deepmodule_installed else object
nnModuleBase = nn.Module if deepmodule_installed else object
# [docs]  (Sphinx "viewcode" link marker left over from the HTML export; not part of the module)
class IntegrAO(LightningModuleBase):
    r"""
    Integrate Any Omics (IntegrAO). [#integraopaper]_ [#integraocode]_

    IntegrAO first combines partially overlapping sample graphs from diverse sources and utilizes graph neural
    networks to produce unified sample embeddings.

    This class provides training, validation, testing, and prediction logic compatible with the
    `Lightning Trainer <https://lightning.ai/docs/pytorch/stable/common/trainer.html>`_.

    Parameters
    ----------
    Xs : list of array-likes objects
        - Xs length: n_mods
        - Xs[i] shape: (n_samples, n_features_i)

        A list of different modalities. It will be used to create the neural network architecture.
    model : nn.Module, default=None
        Deep learning model. If None, it will select IntegrAOModule.
    n_clusters : int, default=8
        The number of clusters to generate.
    neighbor_size : int, default=None
        Number of neighbors to use. If None, it will use N/6.
    hidden_channels : int, default=128
        Hidden dimension size.
    embedding_dims : int, default=50
        Size of the shared embedding space where modalities are projected.
    fusing_iteration : int, default=20
        Number of iterations for fusing.
    mu : float, default=0.5
        Normalization factor to scale similarity kernel.
    learning_rate : float, default=1e-3
        Learning rate for the optimizer.
    weight_decay : float, default=1e-4
        Weight decay used by the optimizer.
    random_state : int, default=None
        Determines the randomness. Use an int to make the randomness deterministic.

    Attributes
    ----------
    embedding_ : array-like of shape (n_samples, n_clusters)
        Common latent feature matrix.
    cluster_model_ : SpectralClustering
        Scikit-learn SpectralClustering object.
    fused_networks_ : list of array-like of shape (n_samples_i, n_samples_i)
        Modal-specific graphs.

    References
    ----------
    .. [#integraopaper] Ma, Shihao, et al. "Moving towards genome-wide data integration for patient stratification
                        with Integrate Any Omics." Nature Machine Intelligence 7.1 (2025): 29-42.
    .. [#integraocode] https://github.com/bowang-lab/IntegrAO

    Example
    --------
    >>> import numpy as np
    >>> import torch
    >>> from imml.cluster import IntegrAO
    >>> from lightning import Trainer
    >>> from torch.utils.data import DataLoader
    >>> from imml.load import IntegrAODataset
    >>> Xs = [torch.from_numpy(np.random.default_rng(42).random((20, 10))) for i in range(3)]
    >>> estimator = IntegrAO(Xs=Xs, random_state=42)
    >>> train_data = IntegrAODataset(Xs=Xs, neighbor_size=estimator.neighbor_size, networks=estimator.fused_networks_)
    >>> train_dataloader = DataLoader(dataset=train_data)
    >>> trainer = Trainer(max_epochs=2, logger=False, enable_checkpointing=False)
    >>> trainer.fit(estimator, train_dataloader)
    >>> labels = trainer.predict(estimator, train_dataloader)[0]
    """

    def __init__(self, Xs, model: nnModuleBase = None, n_clusters: int = 8, neighbor_size: int = None,
                 hidden_channels: int = 128, embedding_dims: int = 50, fusing_iteration: int = 20,
                 mu: float = 0.5, learning_rate: float = 1e-3, weight_decay: float = 1e-4,
                 random_state: int = None):
        if not deepmodule_installed:
            raise ImportError(deepmodule_error)

        super().__init__()

        if not isinstance(n_clusters, int):
            raise ValueError(f"Invalid n_clusters. It must be an int. A {type(n_clusters)} was passed.")
        if n_clusters < 2:
            raise ValueError(f"Invalid n_clusters. It must be an greater than 1. {n_clusters} was passed.")
        if not isinstance(Xs, list):
            raise ValueError(f"Invalid Xs. It must be a list of array-likes objects objects. A {type(Xs)} was passed.")
        if not Xs:
            # Xs[0] is inspected below; fail with a clear message instead of a bare IndexError.
            raise ValueError("Invalid Xs. It must contain at least one modality. An empty list was passed.")
        if not isinstance(learning_rate, float):
            raise ValueError(f"Invalid learning_rate. It must be a float. A {type(learning_rate)} was passed.")
        if learning_rate <= 0:
            raise ValueError(f"Invalid learning_rate. It must be a positive number. {learning_rate} was passed.")

        # The rest of the pipeline works on labelled DataFrames.
        if not isinstance(Xs[0], pd.DataFrame):
            Xs = [pd.DataFrame(X) for X in Xs]
        if neighbor_size is None:
            neighbor_size = int(Xs[0].shape[0]/6)

        self.neighbor_size = neighbor_size
        self.n_clusters = n_clusters
        self.embedding_dims = embedding_dims
        self.fusing_iteration = fusing_iteration
        self.mu = mu
        self.hidden_channels = hidden_channels
        self.loss_mse = nn.MSELoss()
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.random_state = random_state

        Xs = remove_missing_samples_by_mod(Xs=Xs)
        # Sample bookkeeping produced by data_indexing; the attributes are consumed by the graph
        # fusion, the alignment loss and predict_step.
        (
            self.dicts_common,
            self.dicts_commonIndex,
            self.dict_sampleToIndexs,
            self.dicts_unique,
            self.original_order,
            self.dict_original_order,
        ) = data_indexing(Xs)

        # Builds self.fused_networks_ (one fused graph per modality).
        self._network_diffusion(Xs=Xs)

        if model is None:
            model = IntegrAOModule(in_channels=[X.shape[1] for X in Xs], hidden_channels=hidden_channels,
                                   out_channels=embedding_dims)
        self.model = model

        # One target matrix P per modality for the t-SNE style loss. They are kept in a plain list
        # (not registered buffers), so _tsne_loss moves them to the device of the embeddings.
        self.ps = [torch.from_numpy(p_preprocess(network)).float() for network in self.fused_networks_]

    def forward(self, Xs, average=True):
        r"""
        Run the wrapped model on a batch and return its output.
        """
        return self.model(Xs, average=average)

    def _compute_loss(self, batch):
        # Shared by the training/validation/test steps: sum of the per-modality t-SNE style KL
        # terms plus the MSE between the embeddings of every pair of modalities, taken at the
        # positions given by dicts_commonIndex.
        embeddings = self(Xs=batch, average=False)
        kl_loss = sum(self._tsne_loss(self.ps[i], X_embedding) for i, X_embedding in enumerate(embeddings))
        n_mods = len(embeddings)
        alignment_loss = sum(
            self.loss_mse(embeddings[i][self.dicts_commonIndex[(i, j)]],
                          embeddings[j][self.dicts_commonIndex[(j, i)]])
            for i in range(n_mods - 1) for j in range(i + 1, n_mods)
        )
        return kl_loss + alignment_loss

    def training_step(self, batch, batch_idx=None):
        r"""
        Method required for training using `Lightning Trainer <https://lightning.ai/docs/pytorch/stable/common/trainer.html>`_.
        """
        return self._compute_loss(batch)

    def validation_step(self, batch, batch_idx=None):
        r"""
        Method required for validating using `Lightning Trainer <https://lightning.ai/docs/pytorch/stable/common/trainer.html>`_.
        """
        return self._compute_loss(batch)

    def test_step(self, batch, batch_idx=None):
        r"""
        Method required for testing using `Lightning Trainer <https://lightning.ai/docs/pytorch/stable/common/trainer.html>`_.
        """
        return self._compute_loss(batch)

    def predict_step(self, batch, batch_idx=None):
        r"""
        Method required for predicting using `Lightning Trainer <https://lightning.ai/docs/pytorch/stable/common/trainer.html>`_.

        Embeds the batch, builds an affinity matrix from the embeddings and clusters it with
        spectral clustering. Sets ``cluster_model_`` and ``embedding_`` and returns the labels.
        """
        embeddings = self(Xs=batch)
        if torch.is_tensor(embeddings):
            # pandas needs host memory and cannot convert tensors that still require grad.
            embeddings = embeddings.detach().cpu().numpy()
        embeddings = pd.DataFrame(data=embeddings, index=self.dict_sampleToIndexs.keys()).sort_index().values

        dist_final = dist2(embeddings, embeddings)
        Wall_final = snf.compute.affinity_matrix(dist_final, K=self.neighbor_size, mu=self.mu)
        Wall_final = _stable_normalized(Wall_final)

        # Reuse an already created clustering model, if any.
        if getattr(self, "cluster_model_", None) is None:
            self.cluster_model_ = SpectralClustering(n_clusters=self.n_clusters, random_state=self.random_state,
                                                     affinity="precomputed")
        labels = self.cluster_model_.fit_predict(X=Wall_final)
        self.embedding_ = spectral_embedding(self.cluster_model_.affinity_matrix_, n_components=self.n_clusters,
                                             eigen_solver=self.cluster_model_.eigen_solver,
                                             random_state=self.random_state,
                                             eigen_tol=self.cluster_model_.eigen_tol, drop_first=False)
        return labels

    def _network_diffusion(self, Xs):
        # Build one sample-by-sample affinity graph per modality and fuse them; the result is
        # stored in self.fused_networks_.
        S_dfs = []
        for X_idx, X in enumerate(Xs):
            dist_mat = dist2(X.values, X.values)
            S_mat = snf.compute.affinity_matrix(dist_mat, K=self.neighbor_size, mu=self.mu)
            S_dfs.append(pd.DataFrame(data=S_mat, index=self.original_order[X_idx],
                                      columns=self.original_order[X_idx]))

        self.fused_networks_ = self._integrao_fuse(aff=S_dfs, dicts_common=self.dicts_common,
                                                   dicts_unique=self.dicts_unique,
                                                   original_order=self.original_order,
                                                   neighbor_size=self.neighbor_size,
                                                   fusing_iteration=self.fusing_iteration)

    def _integrao_fuse(self, aff, dicts_common, dicts_unique, original_order, neighbor_size=20, fusing_iteration=20):
        # Iteratively diffuse every modality graph through the graphs of the other modalities.
        #
        # aff            : list of square DataFrames (one affinity matrix per modality); the list
        #                  is updated in place.
        # dicts_common   : mapping (j, n) -> samples shared by modalities j and n.
        # dicts_unique   : mapping (j, n) -> samples of modality j that are not in modality n
        #                  (NOTE(review): inferred from how the entries are used below - confirm
        #                  against data_indexing).
        # original_order : per-modality sample order used as index/columns.
        #
        # Returns the list of fused matrices as numpy arrays, or the output of _find_dominate_set
        # for the single modality when only one is given.
        newW = [0] * len(aff)
        for n, mat in enumerate(aff):
            # normalize affinity matrix based on strength of edges
            aff[n] = _stable_normalized_pd(mat)
            # apply KNN threshold to normalized affinity matrix; the neighbourhood can never be
            # larger than the number of samples (the clamped value carries over to later modalities)
            neighbor_size = min(int(neighbor_size), mat.shape[0])
            newW[n] = _find_dominate_set(aff[n], neighbor_size)

        # If there is only one view, return it
        if len(aff) == 1:
            print("Only one view, return it directly")
            return newW

        for iteration in range(fusing_iteration):
            # Zero-filled accumulators with the same labels as each graph. Fresh frames are
            # allocated instead of zeroing a copy through `.values`, which is read-only when
            # pandas Copy-on-Write is enabled.
            aff_next = [
                pd.DataFrame(np.zeros(current.shape), index=current.index, columns=current.columns)
                for current in aff
            ]

            for n in range(len(aff)):
                nzW = newW[n]
                # Both operands only depend on the current view, so build them once per view.
                nzW_identity = pd.DataFrame(
                    data=np.identity(nzW.shape[0]),
                    index=original_order[n],
                    columns=original_order[n],
                )
                nzW_T = np.transpose(nzW)

                for j, mat_tofuse in enumerate(aff):
                    if n == j:
                        continue

                    # reorder mat_tofuse to have the common samples first
                    new_order = sorted(dicts_common[(j, n)]) + sorted(dicts_unique[(j, n)])
                    mat_tofuse = mat_tofuse.reindex(new_order, axis=1)
                    mat_tofuse = mat_tofuse.reindex(new_order, axis=0)

                    # Next, crop mat_tofuse to the common samples
                    num_common = len(dicts_common[(n, j)])
                    to_drop_mat = mat_tofuse.columns[num_common: mat_tofuse.shape[1]].values.tolist()
                    mat_tofuse_crop = mat_tofuse.drop(to_drop_mat, axis=1)
                    mat_tofuse_crop = mat_tofuse_crop.drop(to_drop_mat, axis=0)

                    # Next, add the similarity from the view to fuse to the current view identity
                    # matrix; label alignment leaves NaN for samples missing in the other view.
                    mat_tofuse_union = nzW_identity + mat_tofuse_crop
                    mat_tofuse_union.fillna(0.0, inplace=True)
                    mat_tofuse_union = _scaling_normalized_pd(
                        mat_tofuse_union, ratio=mat_tofuse_crop.shape[0] / nzW_identity.shape[0])
                    mat_tofuse_union = check_symmetric(mat_tofuse_union, raise_warning=False)
                    mat_tofuse_union = mat_tofuse_union.reindex(original_order[n], axis=1)
                    mat_tofuse_union = mat_tofuse_union.reindex(original_order[n], axis=0)

                    # Now we are ready to do the diffusion
                    # (matmul is not working here, but .dot() is good)
                    aff0_temp = nzW.dot(mat_tofuse_union.dot(nzW_T))
                    aff0_temp = _stable_normalized_pd(aff0_temp)

                    aff_next[n] = np.add(aff0_temp, aff_next[n])

                # average over the other views
                aff_next[n] = np.divide(aff_next[n], len(aff) - 1)

            # put the values in aff_next back to aff
            aff[:] = aff_next

        for n, mat in enumerate(aff):
            aff[n] = _stable_normalized_pd(mat)

        return [x.values for x in aff]

    @staticmethod
    def _tsne_loss(P, activations):
        # KL divergence between the target matrix P and the Student-t similarities Q computed
        # from the embeddings (`activations`, one row per sample).
        device = activations.device
        # P may still live on another device (e.g. it was created on CPU); bring it next to the
        # embeddings so the element-wise products below are valid.
        P = P.to(device)
        n = activations.size(0)
        alpha = 1
        eps = 1e-12

        # Pairwise squared Euclidean distances.
        sum_act = torch.sum(torch.pow(activations, 2), 1)
        Q = (
            sum_act
            + sum_act.view([-1, 1])
            - 2 * torch.matmul(activations, torch.transpose(activations, 0, 1))
        )
        Q = Q / alpha
        Q = torch.pow(1 + Q, -(alpha + 1) / 2)
        # Zero the diagonal (self-similarity) before normalising.
        off_diagonal = 1 - torch.eye(n, device=device, dtype=activations.dtype)
        Q = Q * off_diagonal
        Q = Q / torch.sum(Q)
        Q = torch.clamp(Q, min=eps)
        C = torch.log((P + eps) / (Q + eps))
        C = torch.sum(P * C)
        return C
class IntegrAOModule(nnModuleBase):
    r"""
    Default network of IntegrAO: one GraphSAGE encoder per modality followed by a projection head
    shared by all modalities.

    Parameters
    ----------
    in_channels : list of int
        Number of input features of every modality. One encoder is created per entry.
    hidden_channels : int
        Hidden dimension size of the encoders.
    out_channels : int
        Size of the embeddings produced by the encoders and by the projection head.
    num_layers : int, default=2
        Number of layers of every encoder.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2):
        super().__init__()

        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.output_dim = out_channels
        self.num_layers = num_layers

        # Modality-specific encoders.
        self.feature = nn.ModuleList([
            GraphSAGE(
                in_channels=self.in_channels[i],
                hidden_channels=self.hidden_channels,
                num_layers=self.num_layers,
                out_channels=self.output_dim,
                project=False)
            for i in range(len(in_channels))
        ])

        # Projection head applied to the output of every encoder.
        self.feature_show = nn.Sequential(
            nn.Linear(self.output_dim, self.output_dim),
            nn.BatchNorm1d(self.output_dim),
            nn.LeakyReLU(0.1, True),
            nn.Linear(self.output_dim, self.output_dim),
        )

    def forward(self, Xs, average=True):
        r"""
        Embed every modality of a batch.

        Parameters
        ----------
        Xs : sequence
            ``Xs[0]`` holds the per-modality feature tensors, ``Xs[1]`` the per-modality edge
            indices, ``Xs[2]`` the per-modality sample labels and ``Xs[3]`` the labels of the rows
            of the feature tensors. For ``Xs[1]`` and ``Xs[2]`` only the first element of every
            entry is used (presumably the batch dimension added by the DataLoader - TODO confirm).
        average : bool, default=True
            If True, average the modality embeddings of every sample into a single matrix.

        Returns
        -------
        list of torch.Tensor or torch.Tensor
            The per-modality embeddings if ``average`` is False; otherwise a float32 CPU tensor
            with one row per distinct sample label.
        """
        xs, edge_indices, idxs = Xs[0], Xs[1], Xs[2]
        ids = Xs[3].cpu().numpy()

        z_all = []
        for X_idx, (X, edge_index, idx) in enumerate(zip(xs, edge_indices, idxs)):
            device = X.device
            # Select, by label, the rows of the samples that belong to this modality.
            rows = idx[0].cpu().numpy()
            X = pd.DataFrame(X.detach().cpu().numpy(), index=ids).loc[rows]
            # The pandas round-trip leaves the data on the CPU; send it back to where the batch
            # (and therefore the encoder) lives.
            X = torch.from_numpy(X.values).type(torch.float32).to(device)
            z = self.feature[X_idx](X, edge_index[0])
            z = self.feature_show(z)
            z_all.append(z)

        if average:
            rows_per_mod = [idx[0].cpu().numpy() for idx in idxs]
            n_samples = len(np.unique(np.concatenate(rows_per_mod)))
            # Running sum of the embeddings and number of contributing modalities per sample.
            # The frames use a default integer index, so labels are expected to be 0..n_samples-1.
            mean_z = pd.DataFrame(np.zeros((n_samples, z_all[0].shape[1])))
            ones_z = mean_z.copy()
            for rows, z in zip(rows_per_mod, z_all):
                # detach/cpu: numpy conversion fails for tensors that require grad or are on an accelerator.
                mean_z.loc[rows] += z.detach().cpu().numpy()
                ones_z.loc[rows] += 1
            mean_z /= ones_z
            z_all = torch.from_numpy(mean_z.values).type(torch.float32)
        return z_all