# License: BSD-3-Clause
import operator
import networkx as nx
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from sklearn.base import BaseEstimator, ClassifierMixin
from ._monet._aux_monet import _best_samples_to_add, _which_sample_to_remove, _which_view_to_add_to_module, \
_which_view_to_remove_from_module, _score_of_split_module, _weight_of_split_and_add_view, \
_weight_of_split_and_remove_view, _weight_of_new_module, _top_samples_to_switch, \
_weight_of_spreading_module, _weight_of_merged_modules, _Globals, _Sample, _Module, _View, _switch_2_samples
from ..utils import check_Xs
from ..preprocessing import remove_missing_samples_by_mod
[docs]
class MONET(BaseEstimator, ClassifierMixin):
r"""
Multi Omic Clustering by Non-Exhaustive Types (MONET). [#monetpaper]_ [#monetcode]_
MONET operates in two distinct phases to extract meaningful information from multi-omics datasets. In the first
phase, it constructs an edge-weighted graph for each omic, where the nodes represent individual samples, and the
weights indicate the similarity between samples within that particular omic. Moving on to the second phase, MONET
identifies modules by identifying dense subgraphs that are shared across multiple omic graphs.
The resulting output comprises a collection of modules, each representing a subset of samples. These modules are
mutually exclusive, meaning that samples are assigned to only one module. It is important to note that not all
samples are necessarily assigned to a module; those remaining unassigned are referred to as "lonely" samples.
Each module, denoted as M, is characterized by its constituent samples, referred to as samples(M), and the set of
omics it encompasses, denoted as omics(M). Intuitively, samples(M) exhibit similarity with one another
specifically within the omics(M) context.
Parameters
----------
n_clusters : Ignored
Ignored.
num_repeats : int (default=15)
Times the algorithm will be repeated in order to avoid suboptimal (local maximum) solutions. The best solution
will be returned.
similarity_mode : str (default='prob')
One of ['prob', 'corr']. If 'corr', the weighting scheme is computed basen on correlation; if 'prob',
a probabilistic formulation is used.
init_modules : dict (default=None)
an optional module initialization for MONET. A dict mapping between module names to sample ids. All modules
are initialized to cover all views. Set to None to use MONET's seed finding algorithm for initialization.
iters : int (default=500)
Maximal number of iterations.
num_of_seeds : int (default=10)
Number of seeds to create in MONET's module initialization algorithm.
num_of_samples_in_seed : int (default=10)
Number of samples to put in each seeds to create in MONET's module initialization algorithm.
min_mod_size : int (default=10)
Minimal size (number of samples) for a MONET module.
max_samples_per_action : int (default=10)
Maximal number of samples in a single MONET action (maximal number of samples added to a module or replaced
between modules in a single action).
percentile_remove_edge : int (default=None)
Only edges with weight percentile above (for positive weights) or below (for negative weights) this percentile
are kept in the graph. For example, percentile_remove_edge=90 keeps only the 10% edges with highest positive
weight and lowest negative weight in the graph. one keeps all edges in the graph.
random_state : int (default=None)
Determines the randomness. Use an int to make the randomness deterministic.
verbose : bool, default=False
Verbosity mode.
n_jobs : int (default=None)
The number of jobs to run in parallel. None means 1 unless in a joblib.parallel_backend context. -1 means
using all processors.
Attributes
----------
labels_ : array-like of shape (n_samples,)
Labels of each point in training data.
glob_var_ : dict
Module names to Module objects mapping. Every module instance includes its set of
samples (under the "samples" attribute) and its set of views (the "views" attribute).
total_weight_ : float
Sum of the weights (similarity between samples within the module) of all modules.
mod_graphs_ : list of dataframes of shape (n_samples, n_samples)
Graph of each modality.
mod_views_ : list of length n_mods.
Views used for each module.
n_clusters_ : int
Number of clusters.
References
----------
.. [#monetpaper] Rappoport N, Safra R, Shamir R (2020) MONET: Multi-omic module discovery by omic selection. PLoS
Comput Biol 16(9): e1008182. https://doi.org/10.1371/journal.pcbi.1008182.
.. [#monetcode] https://github.com/Shamir-Lab/MONET
Example
--------
>>> import numpy as np
>>> import pandas as pd
>>> from imml.cluster import MONET
>>> Xs = [pd.DataFrame(np.random.default_rng(42).random((20, 10))) for i in range(3)]
>>> estimator = MONET()
>>> labels = estimator.fit_predict(Xs)
"""
def __init__(self, n_clusters: int = None, num_repeats: int = 15, similarity_mode: str = 'corr',
init_modules: dict = None, iters: int = 500, num_of_seeds: int = 10,
num_of_samples_in_seed: int = 10, min_mod_size: int = 10, max_sams_per_action: int = 10,
percentile_remove_edge: int = None, random_state: int = None,
verbose: bool = False, n_jobs: int = None):
similarity_mode_opts = ["corr"]
if similarity_mode not in similarity_mode_opts:
raise ValueError(f"Invalid similarity_mode. Expected one of {similarity_mode_opts}. {similarity_mode} was passed.")
self.n_clusters = n_clusters
self.num_repeats = num_repeats
self.similarity_mode = similarity_mode
self.init_modules = init_modules
self.iters = iters
self.num_of_seeds = num_of_seeds
self.num_of_samples_in_seed = num_of_samples_in_seed
self.min_mod_size = min_mod_size
self.max_sams_per_action = max_sams_per_action
self.percentile_remove_edge = percentile_remove_edge
self.random_state = random_state
self.verbose = verbose
self.n_jobs = n_jobs
# a list of the actions considered by MONET in each iteration. Each action correponds to one function in the list.
self.functions = [_best_samples_to_add, _which_sample_to_remove, _which_view_to_add_to_module,
_which_view_to_remove_from_module, _score_of_split_module, _weight_of_split_and_add_view,
_weight_of_split_and_remove_view, _weight_of_new_module, _top_samples_to_switch,
_weight_of_spreading_module, _weight_of_merged_modules]
[docs]
def fit(self, Xs, y=None):
r"""
Fit the transformer to the input data.
Parameters
----------
Xs : list of array-likes objects
- Xs length: n_mods
- Xs[i] shape: (n_samples_i, n_features_i)
A list of different views.
y : array-like, shape (n_samples,)
Labels for each sample. Only used by supervised algorithms.
Returns
-------
self : returns an instance of self.
"""
Xs = check_Xs(Xs, ensure_all_finite='allow-nan')
if not isinstance(Xs[0], pd.DataFrame):
Xs = [pd.DataFrame(X) for X in Xs]
for X in Xs:
X.index = X.index.astype(str)
samples = Xs[0].index
Xs = remove_missing_samples_by_mod(Xs=Xs)
data = {}
if self.similarity_mode == "corr":
data = self._process_data(Xs=Xs)
solutions = Parallel(n_jobs=self.n_jobs)(
delayed(self._single_run)(
data=data, init_modules=self.init_modules, iters=self.iters, num_of_seeds=self.num_of_seeds,
num_of_samples_in_seed=self.num_of_samples_in_seed, min_mod_size=self.min_mod_size,
max_sams_per_action=self.max_sams_per_action, percentile_remove_edge=self.percentile_remove_edge,
samples = samples, verbose=self.verbose,
random_state=self.random_state + n_time if self.random_state is not None else self.random_state)
for n_time in range(self.num_repeats)
)
solutions = {idx: i for idx, i in enumerate(solutions)}
best_sol = {key: value['total_weight'] for key, value in solutions.items()}
best_sol = max(best_sol.items(), key=operator.itemgetter(1))[0]
best_sol = solutions[best_sol]
glob_var, total_weight = best_sol['glob_var'], best_sol['total_weight']
labels, view_graphs, mod_views = self._post_processing(glob_var=glob_var)
labels = labels.loc[samples].squeeze().values
labels_wo_nan = np.unique(labels, return_inverse=True)[1].astype(float)
labels_wo_nan[np.isnan(labels)] = np.nan
self.labels_ = labels_wo_nan
self.glob_var_ = glob_var
self.total_weight_ = total_weight
self.view_graphs_ = view_graphs
self.mod_views_ = mod_views
self.n_clusters_ = len(np.unique(labels_wo_nan[~np.isnan(labels_wo_nan)]))
return self
def _predict(self, Xs):
r"""
Return clustering results for samples.
Parameters
----------
Xs : list of array-likes objects
- Xs length: n_mods
- Xs[i] shape: (n_samples_i, n_features_i)
A list of different views.
Returns
-------
labels : list of array-likes objects, shape (n_samples,)
The predicted data.
"""
labels = self.labels_
return labels
[docs]
def fit_predict(self, Xs, y=None):
r"""
Fit the model and return clustering results.
Convenience method; equivalent to calling fit(X) followed by predict(X).
Parameters
----------
Xs : list of array-likes objects
- Xs length: n_mods
- Xs[i] shape: (n_samples_i, n_features_i)
A list of different views.
Returns
-------
labels : ndarray, shape (n_samples,)
The predicted data.
"""
labels = self.fit(Xs)._predict(Xs)
return labels
def _single_run(self, data, init_modules, iters, num_of_seeds, num_of_samples_in_seed, min_mod_size,
max_sams_per_action, percentile_remove_edge, samples, random_state, verbose):
r"""
"""
if random_state is None:
random_state = np.random.default_rng().integers(100000)
glob_var = _Globals(len(self.functions))
glob_var = self._create_env(samples = samples, glob_var=glob_var, data=data,
percentile_remove_edge=percentile_remove_edge)
glob_var.min_mod_size = min_mod_size
glob_var.max_samps_per_action = max_sams_per_action
if init_modules is None:
glob_var = self._get_seeds(glob_var, num_of_seeds=num_of_seeds,
num_of_samples_in_seed=num_of_samples_in_seed, random_state=random_state)
else:
glob_var = self._create_seeds_from_solution(glob_var, init_modules)
for some_mod in glob_var.modules.copy().values():
if len(some_mod.samples) < min_mod_size:
if verbose:
print('killing a small module before starting')
glob_var.kill_module(some_mod)
total_weight = sum(mod.get_weight() for mod in glob_var.modules.values())
converged_modules = {}
did_action = False
iterations = 0
while iterations < iters:
prev_weight = total_weight
active_module_names = list(sorted(set(glob_var.modules.keys()) - set(converged_modules.keys())))
if len(active_module_names) == 0:
if not did_action:
if verbose:
print("converged, total score: {}.".format(total_weight))
break
else:
converged_modules = {}
did_action = False
active_module_names = list(sorted(glob_var.modules.keys()))
mod_name = np.random.default_rng(random_state + iterations).choice(active_module_names).tolist()
mod = glob_var.modules[mod_name]
max_res = self._get_next_step(mod, glob_var)
glob_var = self._exec_next_step(mod, max_res, glob_var)
for _, some_mod in glob_var.modules.copy().items():
if len(some_mod.get_samples()) <= 1 or not some_mod.get_views():
glob_var.kill_module(some_mod)
if verbose:
print('removing zombie module')
total_weight = sum([mod.get_weight() for name, mod in glob_var.modules.items()])
iterations += 1
if (iterations % 10 == 0) and verbose:
print("iteration: " + str(iterations))
print("num of modules: " + str(len(glob_var.modules)))
print("total_weight: " + str(total_weight))
print("actions: " + str(glob_var.actions))
# Assert module sizes
for _, some_mod in glob_var.modules.copy().items():
assert len(some_mod.samples) >= min_mod_size
if total_weight <= prev_weight or max_res[1][0] == -float("inf"):
if mod_name in glob_var.modules:
converged_modules.update({mod_name: glob_var.modules[mod_name]})
else: # the score deviates from the score we expected
if not (abs(total_weight - prev_weight - max_res[1][0]) < 0.01):
# This signifies a bug and should never occur:
# that the difference in the objective function from the
# previous iteration is different from the difference
# the algorithm expected for the function.
raise Exception("The clusters could not be found.")
did_action = True
assert abs(total_weight - prev_weight - max_res[1][0]) < 0.01
did_action = True
for mod_name, mod in glob_var.modules.copy().items():
if mod.get_size() <= glob_var.min_mod_size and not self._is_mod_significant(mod, glob_var,
random_state=random_state):
if verbose:
print("module {} with samples {} on views {} is not significant.".format((mod_name, mod),
mod.get_samples(),
mod.get_views().keys()))
glob_var.kill_module(mod)
return {"glob_var": glob_var, "total_weight": total_weight}
@staticmethod
def _process_data(Xs: list):
"""gets raw data and return a list of similarity matrices"""
data = {}
for X_idx, X in enumerate(Xs):
X_t = X.copy().T
X_t.columns = X_t.columns
X_t = X_t.corr()
np.fill_diagonal(X_t.values, 0)
data[str(X_idx)] = X_t
return data
def _create_env(self, samples, glob_var, data, percentile_remove_edge):
"""
Create all the variables used during MONET's run:
modules, modality, etc, and associating them with a Global instance.
"""
all_sam_names = set(samples)
glob_var.samples = {sample: _Sample(sample) for sample in all_sam_names}
for view, dat in data.items():
self.view = view
graph, means, covs, percentile = self._build_a_graph_similarity(dat)
if percentile_remove_edge is not None:
all_weights = [graph.edges[edge]['weight'] for edge in graph.edges]
all_weights_array = np.array(all_weights)
positive_thresh = np.percentile(all_weights_array[all_weights_array > 0], percentile_remove_edge)
negative_thresh = np.percentile(all_weights_array[all_weights_array < 0], 100 - percentile_remove_edge)
all_edges = [edge for edge in graph.edges]
for edge in all_edges:
cur_weight = graph.edges[edge]['weight']
if (cur_weight > 0 and cur_weight < positive_thresh) or (cur_weight < 0 and cur_weight > negative_thresh):
graph.remove_edge(edge[0], edge[1])
cur_graph_sams = set(graph.nodes)
missing_sams = all_sam_names - cur_graph_sams
for missing_sam in missing_sams:
graph.add_node(missing_sam)
glob_var.views.update({view: _View(graph=graph, name=view)})
glob_var.gmm_params.update({view: {'mean': means, 'cov': covs, 'percentile': percentile}})
return glob_var
def _create_seeds_from_solution(self, glob_var, init_modules):
for mod_name, sam_ids in init_modules.items():
views = glob_var.views
sam_dict = {}
for sam_id in sam_ids:
sam_dict[sam_id] = glob_var.samples[sam_id]
mod_weight = 0
for view in views.values():
mod_weight += view.graph.subgraph(list(sam_dict.keys())).size('weight')
_Module(glob_var=glob_var, samples=sam_dict, views=views, weight=mod_weight)
return glob_var
def _get_seeds(self, glob_var, num_of_seeds=3, num_of_samples_in_seed=10, random_state: int = None):
"""
Create seed modules.
"""
lst = list(glob_var.views.items())
lst.sort(key=lambda x: x[0])
views_list = [view for name, view in lst]
sam_list = list(glob_var.samples.keys())
adj = np.zeros((len(sam_list), len(sam_list)))
for name, view in lst:
adj += nx.adjacency_matrix(view.graph.subgraph(sam_list), nodelist=sam_list)
adj = pd.DataFrame(adj, index=sam_list, columns=sam_list)
joined_subgraph = nx.from_pandas_adjacency(adj)
view_graphs = [joined_subgraph]
for i in range(num_of_seeds):
view_graph = view_graphs[0]
cur_nodes = list(sorted(view_graph.nodes()))
adj = list(view_graph.adjacency())
if len(cur_nodes) == 0:
break
rand_sam_index = np.random.default_rng(random_state + i).integers(0, max([len(cur_nodes) - 1, 1]))
rand_sam_name = cur_nodes[rand_sam_index]
rand_sam_in_adj = [sam[0] for sam in adj].index(rand_sam_name)
neighbors = [(key, adj[rand_sam_in_adj][1][key]['weight']) for key in adj[rand_sam_in_adj][1]]
neighbors = sorted(neighbors, key=lambda x: x[1], reverse=True)[:(num_of_samples_in_seed - 1)]
nodes = {rand_sam_name: glob_var.samples[rand_sam_name]}
for nei in neighbors:
if nei[1] > 0 and nei[0] != rand_sam_name:
nodes.update({nei[0]: glob_var.samples[nei[0]]})
mod_weight = view_graph.subgraph(list(nodes.keys())).size('weight')
if mod_weight > 0 and len(nodes) > 1 and len(nodes) >= glob_var.min_mod_size:
_Module(glob_var=glob_var, samples=nodes, views=[view for view in views_list], weight=mod_weight)
for k in range(len(view_graphs)):
view_graph = view_graphs[k]
remaining_nodes = list(sorted(set(cur_nodes) - set(nodes.keys())))
view_graphs[k] = view_graph.subgraph(remaining_nodes)
return glob_var
def _build_a_graph_similarity(self, distances):
g = nx.from_numpy_array(distances.values)
mapping = {i: j for i, j in enumerate(distances.columns)}
nx.relabel_nodes(g, mapping, False)
return g, [], [], 0
def _get_next_step(self, mod, glob_var):
"""
this function decided what is the next action that will be executed.
"""
max_res = (-1, (-float("inf"), None))
for func_i in range(len(self.functions)):
if func_i <= 9: # only one module needed
tmp = self.functions[func_i](mod, glob_var)
if tmp[0] > max_res[1][0]:
max_res = (func_i, tmp)
else:
for mod2 in glob_var.modules.values():
if mod2 == mod:
continue
tmp = self.functions[func_i](mod, mod2, glob_var)
if tmp[0] > max_res[1][0]:
max_res = (func_i, tmp)
return max_res
def _exec_next_step(self, mod, max_res, glob_var):
"""
this function actually performs an action, given that the
algorithm already decided what the next action will be.
"""
if max_res[1][0] == -float("inf") or max_res[1][0] < 0:
return glob_var
func_i = max_res[0]
glob_var.actions[func_i] += 1
if func_i == 0: # add
for sample in max_res[1][1]:
mod.add_sample(sample)
elif func_i == 1: # remove
mod.remove_sample(max_res[1][1])
if len(mod.get_samples()) <= 1:
glob_var = glob_var.kill_module(mod)
elif func_i == 2: # add view
mod.add_view(max_res[1][1], glob_var)
elif func_i == 3: # remove view
mod.remove_view(max_res[1][1], glob_var)
elif func_i == 4: # split
glob_var = mod.split_module(max_res[1][1][1], glob_var)
elif func_i == 5: # split and add view
glob_var = mod.split_and_add_view(view=max_res[1][1][0], sub_nodes=max_res[1][1][1], glob_var=glob_var)
elif func_i == 6: # split and remove view
glob_var = mod.split_and_remove_view(view=max_res[1][1][0], sub_nodes=max_res[1][1][1], glob_var=glob_var)
elif func_i == 7: # create new module
new_mod = _Module(glob_var)
new_mod.add_view(max_res[1][1][1], glob_var)
for sam in max_res[1][1][0]:
new_mod.add_sample(glob_var.samples[sam])
elif func_i == 8: # transfer
sams = [(sam, mod2) for sam, weight, mod2 in max_res[1][1]]
for sam, mod2 in sams:
_switch_2_samples(glob_var.samples[sam], mod, mod2, glob_var)
elif func_i == 9: # spread module
mod.spread_module(max_res[1][1], glob_var)
elif func_i == 10: # merge
glob_var = mod.merge_with_module(max_res[1], glob_var)
return glob_var
def _is_mod_significant(self, mod, glob_var, percentile=95, iterations=500, random_state: int = None):
"""
Assess the statisitcal significance of a module by sampling modules or similar size.
"""
draws = [0 for i in range(iterations)]
mod_size = len(mod.get_samples())
if mod_size <= 1:
return False
for i in range(iterations):
samps = np.random.default_rng(random_state + i).choice(list(glob_var.samples.keys()),
mod_size).tolist()
lst = list(mod.get_views().items())
lst.sort(key=lambda x: x[0])
for name, view in lst:
draws[i] += view.graph.subgraph(samps).size('weight')
num_to_beat = np.percentile(draws, percentile)
return mod.get_weight() > num_to_beat
@staticmethod
def _post_processing(glob_var):
labels = [[sample, mod_id] for mod_id, module in glob_var.modules.items() for sample in module.samples]
labels = pd.DataFrame(labels)
labels = labels.set_index(0)
sams_without_mods = pd.DataFrame(None, index=glob_var.samples.keys())
labels = pd.concat([labels, sams_without_mods.loc[sams_without_mods.index.difference(labels.index)]])
view_graphs = [pd.DataFrame(nx.to_numpy_array(view.graph)) for view in glob_var.views.values()]
mod_views = {mod_name: list(module.get_views().keys()) for mod_name, module in glob_var.modules.items()}
return labels, view_graphs, mod_views
# def _get_em_graph_per_view(self, Xs: list, predictions=False):
# """gets raw data and return a list of similarity matrices"""
#
# sim_data = [1 - X.T.corr() for X in Xs]
# all_views_ret = []
# for i, cur_sim in enumerate(sim_data):
# if predictions:
# num_clusters = self.n_clusters_[i]
# else:
# n_clusters = list(range(2, 11))
# scores = [silhouette_score(cur_sim,
# SpectralClustering(n_clusters=k, affinity='precomputed', n_jobs = -1,
# random_state= self.random_state).fit_predict(cur_sim)) \
# for k in n_clusters]
# num_clusters = n_clusters[np.argmax(scores)]
#
# if predictions:
# em_ret = self.all_ems_[i]
# else:
# chosen_sims_mat = cur_sim.sample(frac=1., random_state=42)
# chosen_sims_mat = chosen_sims_mat[chosen_sims_mat.index]
# chosen_sims = chosen_sims_mat.values[np.triu_indices_from(chosen_sims_mat, k=1)]
# em_ret = GaussianMixture(n_components=2, n_init=20, max_iter=int(1e5),
# random_state= self.random_state).fit(pd.DataFrame(chosen_sims))
#
# # calculate probabilities
# sigma = [np.sqrt(np.trace(em_ret.covariances_[i]) / 2) for i in range(2)]
# prob1 = np.log(em_ret.weights_[0]) + norm.logpdf(cur_sim, loc=em_ret.means_[0], scale=sigma[0])
# prob2 = np.log(em_ret.weights_[1]) + norm.logpdf(cur_sim, loc=em_ret.means_[1], scale=sigma[1])
# prob = prob1 - prob2 if em_ret.means_[0] < em_ret.means_[1] else prob2 - prob1
# shift_by = np.quantile(prob[np.triu_indices_from(prob, k=1)], 1 - 1 / num_clusters)
# prob = prob - shift_by
# np.fill_diagonal(prob, 0)
# prob = pd.DataFrame(prob, index = Xs[i].index.astype(str), columns=Xs[i].index.astype(str))
# all_views_ret.append({"prob": prob, "all_views_ems": em_ret, "num_clusters": num_clusters})
# return all_views_ret
# def _monet_ret_to_module_membership(self, view_graphs=None):
# if view_graphs is None:
# view_graphs = self.view_graphs_
# mod_views = self.mod_views_
# mod_names = list(self.glob_var_.modules.keys())
# samples = self.labels_.index
# all_module_membership = []
# for mod_name in mod_names:
# cur_mod_views = mod_views[mod_name]
# cur_mod_samples = samples[self.labels_ == mod_name] + 'fit'
# cur_module_membership = sum([view_graphs[int(i)][cur_mod_samples].sum(axis=0) for i in cur_mod_views])
# all_module_membership.append(cur_module_membership.tolist())
# all_module_membership = pd.DataFrame(all_module_membership, index=mod_names)
# all_module_membership = all_module_membership.T.idxmax(1)
# all_module_membership = all_module_membership[all_module_membership.index.difference(samples.astype(int))]
# return all_module_membership