# License: BSD-3-Clause
import os
import tempfile
import urllib.request
import zipfile
from os.path import dirname

import numpy as np
from PIL import Image

from ._m3care import NMT_tran, MM_transformer_encoder, init_weights, PositionalEncoding, clones, \
    GraphConvolution, length_to_mask, guassian_kernel
from .. import deepmodule_installed, deepmodule_error, LightningModule, Module

if deepmodule_installed:
    import torch
    from torch import optim, nn
    from torchvision import models as models
    import torch.nn.functional as F
    import torchvision.transforms as transforms
class M3Care(LightningModule):
    r"""
    Missing Modalities in Multimodal healthcare data (M3Care). [#m3carepaper]_ [#m3carecode]_ [#m3carecode2]_

    M3Care is a multimodal classification framework that handles missing modalities by imputing latent
    task-relevant information using similar samples, based on a modality-adaptive similarity metric.
    It supports heterogeneous input types (e.g., tabular, text, vision).

    This class provides training, validation, testing, and prediction logic compatible with the
    `Lightning Trainer <https://lightning.ai/docs/pytorch/stable/common/trainer.html>`_.

    Parameters
    ----------
    input_dim : list of int, default=None
        A list specifying the input dimensions for each tabular modality.
    hidden_dim : int, default=128
        Hidden dimension size.
    embed_size : int, default=128
        Size of the shared embedding space where modalities are projected.
    modalities : list of str, default=None
        Names of the modalities. Options are "tabular", "text" and "image". At least two are required.
    vocab : list, default=None
        List with path to corpus file, maximum number of words in vocabulary, and freq_cutoff
        (if word occurs n < freq_cutoff times, drop the word). If you want to pass your own Vocab object, use just
        a list with one element [Vocab]. If None, ["test.de-en.en", 50000, 2] will be
        used (if applicable); when that file is not available locally the user is asked on stdin for
        permission to download it. [#m3carecode]_
    learning_rate : float, default=1e-4
        Learning rate for the optimizer.
    weight_decay : float, default=1e-4
        Weight decay used by the optimizer.
    output_dim : int, default=1
        Number of classes in your response variable. Typically 1 for binary classification.
    loss_fn : callable, default=None
        Loss function. If None, defaults to `nn.BCEWithLogitsLoss()` if output_dim == 1, else
        `nn.CrossEntropyLoss()`.
    keep_prob : float, default=0.5
        Dropout keep probability used in MLP layers (the dropout rate is ``1 - keep_prob``).
    extractors : list of nn.Module, default=None
        List of custom feature extractors for each modality. If None, defaults will be used.

    References
    ----------
    .. [#m3carepaper] Zhang, Chaohe, et al. "M3care: Learning with missing modalities in multimodal healthcare data."
                      Proceedings of the 28th ACM SIGKDD conference on knowledge discovery and data mining. 2022.
    .. [#m3carecode] https://github.com/choczhang/M3Care/
    .. [#m3carecode2] https://github.com/pcyin/pytorch_basic_nmt/tree/master

    See Also
    --------
    :class:`~imml.load.M3CareDataset`

    Example
    --------
    >>> from lightning import Trainer
    >>> import numpy as np
    >>> import pandas as pd
    >>> from torch.utils.data import DataLoader
    >>> from imml.classify import M3Care
    >>> from imml.load import M3CareDataset
    >>> from imml.ampute import Amputer
    >>> Xs = [pd.DataFrame(np.random.default_rng(42).random((2, 10)))]
    >>> Xs.append(pd.DataFrame(np.random.default_rng(42).random((2, 15))))
    >>> Xs.append(pd.DataFrame(["docs/figures/graph.png", "docs/figures/logo_imml.png"]))
    >>> Xs.append(pd.DataFrame(["This is the graphical abstract of iMML.", "This is the logo of iMML."]))
    >>> Xs = Amputer(p=0.2, random_state=42).fit_transform(Xs) # this step is optional
    >>> y = pd.Series(np.random.default_rng(42).integers(0, 2, len(Xs[0])), dtype=np.float32)
    >>> train_data = M3CareDataset(Xs=Xs, y=y)
    >>> train_dataloader = DataLoader(dataset=train_data, batch_size=10, shuffle=True)
    >>> trainer = Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
    >>> modalities = ["tabular", "tabular", "image", "text"]
    >>> estimator = M3Care(modalities=modalities, input_dim=[X.shape[1] for X,mod in zip(Xs, modalities) if mod=="tabular"])
    >>> trainer.fit(estimator, train_dataloader)
    >>> trainer.predict(estimator, train_dataloader)
    """

    def __init__(self, input_dim: list = None, hidden_dim: int = 128, embed_size: int = 128, modalities: list = None,
                 vocab: list = None, learning_rate: float = 1e-4, weight_decay: float = 1e-4, output_dim: int = 1,
                 loss_fn: callable = None, keep_prob: float = 0.5, extractors: list = None):
        if not deepmodule_installed:
            raise ImportError(deepmodule_error)

        # ---- argument validation (every failure is reported as ValueError) ----
        if input_dim is not None and not isinstance(input_dim, list):
            raise ValueError(f"Invalid input_dim. It must be a list. A {type(input_dim)} was passed.")
        if not isinstance(hidden_dim, int):
            raise ValueError(f"Invalid hidden_dim. It must be an integer. A {type(hidden_dim)} was passed.")
        if hidden_dim <= 0:
            raise ValueError(f"Invalid hidden_dim. It must be positive. {hidden_dim} was passed.")
        if not isinstance(embed_size, int):
            raise ValueError(f"Invalid embed_size. It must be an integer. A {type(embed_size)} was passed.")
        if embed_size <= 0:
            raise ValueError(f"Invalid embed_size. It must be positive. {embed_size} was passed.")
        if not isinstance(modalities, list):
            raise ValueError(f"Invalid modalities. It must be a list. A {type(modalities)} was passed.")
        if len(modalities) < 2:
            raise ValueError(f"Invalid modalities. It must have at least two modalities. Got {len(modalities)} modalities")
        modalities_options = ["tabular", "text", "image"]
        if not all(mod in modalities_options for mod in modalities):
            raise ValueError(f"Invalid modalities. Expected options are: {modalities_options}")
        if not isinstance(learning_rate, float):
            raise ValueError(f"Invalid learning_rate. It must be a float. A {type(learning_rate)} was passed.")
        if learning_rate <= 0:
            raise ValueError(f"Invalid learning_rate. It must be positive. {learning_rate} was passed.")
        if not isinstance(weight_decay, float):
            raise ValueError(f"Invalid weight_decay. It must be a float. A {type(weight_decay)} was passed.")
        if weight_decay < 0:
            raise ValueError(f"Invalid weight_decay. It must be non-negative. {weight_decay} was passed.")
        if not isinstance(output_dim, int):
            raise ValueError(f"Invalid output_dim. It must be an integer. A {type(output_dim)} was passed.")
        if output_dim <= 0:
            raise ValueError(f"Invalid output_dim. It must be positive. {output_dim} was passed.")
        if loss_fn is not None and not callable(loss_fn):
            raise ValueError(f"Invalid loss_fn. It must be callable. A {type(loss_fn)} was passed.")
        if not isinstance(keep_prob, float):
            raise ValueError(f"Invalid keep_prob. It must be a float. A {type(keep_prob)} was passed.")
        if keep_prob <= 0 or keep_prob > 1:
            raise ValueError(f"Invalid keep_prob. It must be between 0 and 1. {keep_prob} was passed.")
        if extractors is not None and not isinstance(extractors, list):
            raise ValueError(f"Invalid extractors. It must be a list. A {type(extractors)} was passed.")

        # ---- vocabulary: fall back to the default corpus stored next to this module ----
        if vocab is None:
            vocab_filename = "test.de-en.en"
            # The corpus lives in a sibling folder named "_<this module's name>".
            vocab_folder = os.path.join(dirname(__file__), "_" + (os.path.basename(__file__).split(".")[0]))
            vocab_path = os.path.join(vocab_folder, vocab_filename)
            if not os.path.exists(vocab_path):
                try:
                    download_vocab = input(
                        "You have not provided a vocab. If you want to use the default vocab, could you allow the download? Enter a bool"
                    )
                except EOFError:
                    # No interactive stdin (e.g. batch job): treat as "permission not given".
                    download_vocab = ""
                if not self._is_affirmative(download_vocab):
                    raise ValueError("You have not provided a vocab. Please provide a vocab or allow the download.")
                self._download_default_vocab(vocab_folder, vocab_filename)
            vocab = [vocab_path, 50000, 2]
        elif not isinstance(vocab, list):
            raise ValueError(f"Invalid vocab. It must be a list. A {type(vocab)} was passed.")

        super().__init__()

        self.model = M3CareModule(input_dim=input_dim, hidden_dim=hidden_dim, embed_size=embed_size, vocab=vocab,
                                  modalities=modalities, output_dim=output_dim, keep_prob=keep_prob,
                                  extractors=extractors)
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

        if loss_fn is None:
            loss_fn = nn.BCEWithLogitsLoss() if output_dim == 1 else nn.CrossEntropyLoss()
        self.loss_fn = loss_fn
        # Maps the raw model outputs to probabilities in `predict_step`.
        self.get_probs = nn.Sigmoid() if output_dim == 1 else nn.Softmax(dim=-1)

    @staticmethod
    def _is_affirmative(answer) -> bool:
        """Return True only if the typed answer is an explicit yes ("true", "t", "yes", "y" or "1")."""
        # `bool("False")` is True, so the raw string must be parsed instead of cast.
        return str(answer).strip().lower() in {"true", "t", "yes", "y", "1"}

    @staticmethod
    def _download_default_vocab(vocab_folder: str, vocab_filename: str) -> None:
        """Download the IWSLT'14 archive and extract `vocab_filename` directly into `vocab_folder`."""
        # A temporary directory keeps the archive out of the working directory and is removed
        # even when the download or the extraction fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            archive_path = os.path.join(tmp_dir, "iwslt2014_ende.zip")
            urllib.request.urlretrieve("http://www.cs.cmu.edu/~pengchey/iwslt2014_ende.zip", archive_path)
            with zipfile.ZipFile(archive_path, "r") as zip_ref:
                # Zip member names always use "/" regardless of the host OS.
                member = zip_ref.getinfo("data/" + vocab_filename)
                # Drop the "data/" prefix so the file lands directly in `vocab_folder`.
                member.filename = os.path.basename(member.filename)
                zip_ref.extract(member, vocab_folder)

    def _shared_step(self, batch):
        """Run the model on a ``(Xs, y, observed_mod_indicator)`` batch and return the loss."""
        Xs, y, observed_mod_indicator = batch
        y_pred, _ = self.model(Xs=Xs, observed_mod_indicator=observed_mod_indicator)
        return self.loss_fn(y_pred, y)

    def configure_optimizers(self):
        r"""
        Method required for training using `Lightning Trainer <https://lightning.ai/docs/pytorch/stable/common/trainer.html>`_.

        Builds the optimizer from ``learning_rate`` and ``weight_decay``.
        """
        # NOTE(review): Adam is assumed here; confirm it matches the intended training setup.
        return optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)

    def training_step(self, batch, batch_idx=None):
        r"""
        Method required for training using `Lightning Trainer <https://lightning.ai/docs/pytorch/stable/common/trainer.html>`_.
        """
        return self._shared_step(batch)

    def validation_step(self, batch, batch_idx=None):
        r"""
        Method required for validating using `Lightning Trainer <https://lightning.ai/docs/pytorch/stable/common/trainer.html>`_.
        """
        return self._shared_step(batch)

    def test_step(self, batch, batch_idx=None):
        r"""
        Method required for testing using `Lightning Trainer <https://lightning.ai/docs/pytorch/stable/common/trainer.html>`_.
        """
        return self._shared_step(batch)

    def predict_step(self, batch, batch_idx=None):
        r"""
        Method required for predicting using `Lightning Trainer <https://lightning.ai/docs/pytorch/stable/common/trainer.html>`_.

        Returns the model outputs passed through a sigmoid (``output_dim == 1``) or a softmax over the
        last dimension (otherwise).
        """
        Xs, _, observed_mod_indicator = batch
        y_pred, _ = self.model(Xs=Xs, observed_mod_indicator=observed_mod_indicator)
        return self.get_probs(y_pred)
class M3CareModule(Module):
    r"""
    Network used by :class:`M3Care`.

    Each modality is encoded by its own extractor, pairwise sample similarities are computed per modality
    and averaged over the modalities observed in both samples, two graph-convolution layers per modality
    propagate information between similar samples, and the resulting tokens are fused by two multimodal
    transformer encoders followed by a small MLP head.

    Parameters
    ----------
    input_dim : list of int, default=None
        Input dimension of every tabular modality that uses the default extractor, in modality order.
    hidden_dim : int, default=128
        Hidden dimension size shared by all extractors and fusion layers.
    embed_size : int, default=128
        Embedding size passed to the default text extractor.
    modalities : list of str, default=None
        Modality names ("tabular", "text" or "image"), one per input.
    vocab : list, default=None
        Vocabulary specification passed to the default text extractor.
    output_dim : int, default=1
        Size of the output layer.
    keep_prob : float, default=1
        Keep probability; dropout is applied with rate ``1 - keep_prob``.
    extractors : list, default=None
        Custom extractor per modality; ``None`` entries (or ``None`` for the whole list) select the defaults.
    """

    def __init__(self, input_dim: list = None, hidden_dim: int = 128, embed_size: int = 128, modalities: list = None,
                 vocab=None, output_dim: int = 1, keep_prob: float = 1, extractors: list = None):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.modalities = modalities
        self.output_dim = output_dim
        self.keep_prob = keep_prob
        self.n_mods = len(modalities)
        self.input_dim = input_dim
        # Image preprocessing is needed by `forward` for every image modality, including those
        # that come with a custom extractor, so it is created independently of the extractor.
        self.preprocess_img = None

        if extractors is None:
            extractors = [None] * self.n_mods
        elif len(extractors) < self.n_mods:
            # zip() below would silently leave the trailing modalities without an extractor.
            raise ValueError(f"Invalid extractors. It must have one entry per modality "
                             f"({self.n_mods}). Got {len(extractors)} entries.")

        # Dimensions are consumed in order, one per tabular modality that uses the default extractor.
        tabular_dims = iter(input_dim if input_dim is not None else ())

        for i, (mod, extractor) in enumerate(zip(self.modalities, extractors)):
            if mod == "tabular":
                if extractor is None:
                    try:
                        in_features = next(tabular_dims)
                    except StopIteration:
                        raise ValueError("Invalid input_dim. It must contain one dimension for every tabular "
                                         "modality without a custom extractor.") from None
                    extractor = nn.Linear(in_features, hidden_dim)
            elif mod == "text":
                if extractor is None:
                    extractor = NMT_tran(embed_size=embed_size, hidden_size=hidden_dim,
                                         dropout_rate=1 - self.keep_prob, vocab=vocab)
            elif mod == "image":
                if self.preprocess_img is None:
                    self.preprocess_img = transforms.Compose([
                        transforms.Resize(256),
                        transforms.CenterCrop(224),
                        transforms.ToTensor(),
                        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225]),
                    ])
                if extractor is None:
                    # resnet18 outputs 1000 values, which are projected to the hidden size.
                    extractor = nn.Sequential(models.resnet18(),
                                              nn.Linear(1000, self.hidden_dim))
            # Attribute names are part of the state_dict layout; keep them stable.
            setattr(self, f"extractor{i}", extractor)

        self.MM_model1 = MM_transformer_encoder(input_dim=self.hidden_dim, d_model=self.hidden_dim,
                                                MHD_num_head=4, d_ff=self.hidden_dim * 4, output_dim=1)
        self.MM_model2 = MM_transformer_encoder(input_dim=self.hidden_dim, d_model=self.hidden_dim,
                                                MHD_num_head=1, d_ff=self.hidden_dim * 4, output_dim=1)

        # One token-type id per modality; never fewer than 6 rows so existing layouts stay unchanged.
        self.token_type_embeddings = nn.Embedding(max(6, self.n_mods), self.hidden_dim)
        self.token_type_embeddings.apply(init_weights)

        self.PositionalEncoding = PositionalEncoding(self.hidden_dim, dropout=0, max_len=5000)
        self.dropout = nn.Dropout(p=1 - self.keep_prob)

        self.proj1 = nn.Linear(self.hidden_dim * len(self.modalities), self.hidden_dim * 2)
        self.out_layer = nn.Linear(self.hidden_dim * 2, self.output_dim)

        # Learnable similarity threshold (passed through a sigmoid in `forward`).
        self.threshold = nn.Parameter(torch.ones(size=(1,)) + 1)
        self.bn = nn.BatchNorm1d(self.hidden_dim)
        # Per-modality projection used by the learned part of the similarity metric.
        self.simiProj = clones(torch.nn.Sequential(
            torch.nn.Linear(self.hidden_dim, self.hidden_dim, bias=True), nn.ReLU(),
            torch.nn.Linear(self.hidden_dim, self.hidden_dim, bias=True), nn.ReLU(),
            torch.nn.Linear(self.hidden_dim, self.hidden_dim, bias=True),
        ), self.n_mods)
        self.GCN1 = clones(GraphConvolution(self.hidden_dim, self.hidden_dim, bias=True), self.n_mods)
        self.GCN2 = clones(GraphConvolution(self.hidden_dim, self.hidden_dim, bias=True), self.n_mods)
        self.weight1 = clones(nn.Linear(self.hidden_dim, 1), self.n_mods)
        self.weight2 = clones(nn.Linear(self.hidden_dim, 1), self.n_mods)
        # Per-modality gate (after a sigmoid) used when combining the two kernel matrices.
        self.eps = nn.ParameterList([nn.Parameter(torch.ones(1) + 1) for _ in range(self.n_mods)])

    def _load_image(self, img_path):
        """Load one image as a preprocessed tensor; a falsy path yields a black placeholder image."""
        if bool(img_path):
            # `convert` returns an independent image, so the file handle can be closed right away.
            with Image.open(img_path) as opened:
                img = opened.convert('RGB')
        else:
            img = Image.new("RGB", (256, 256), (0, 0, 0))
        return self.preprocess_img(img)

    def forward(self, Xs, observed_mod_indicator):
        r"""
        Compute the predictions for a batch.

        Parameters
        ----------
        Xs : list
            One entry per modality, in the order of ``modalities``: a tensor for "tabular", an iterable of
            image paths for "image" and an iterable of strings for "text". Falsy paths/strings are replaced
            by a black image / the token "None".
        observed_mod_indicator : torch.Tensor
            Indexed as ``[:, modality]``; truthy where the modality is observed for the sample.
            (presumably of shape (n_samples, n_mods) — verify against the dataset)

        Returns
        -------
        output : torch.Tensor
            Output of the final linear layer, squeezed along dimension 1.
        sum_of_diff : torch.Tensor
            Per sample, the sum over modalities of ``| ||simiProj(h)|| - ||h|| |``.
        """
        device = self.token_type_embeddings.weight.device

        feats = []        # full extractor outputs (a token sequence for text)
        hidden00 = []     # one vector per sample and modality (first token for text)
        mask_mats = []    # token masks built by `length_to_mask`
        mask_mats_ = []   # per-sample "modality observed" column
        mask2_mats = []   # pairwise "observed in both samples" matrix
        for X_idx, (X, mod) in enumerate(zip(Xs, self.modalities)):
            extractor = getattr(self, f"extractor{X_idx}")
            if mod == 'tabular':
                X = X.clone()
                # Rows that are entirely NaN are zero-filled before the extractor.
                X[X.isnan().all(1)] = 0
                feat = F.relu(extractor(X))
                # Every sample contributes exactly one token.
                mask = torch.ones(feat.shape[0], dtype=torch.int32)
                feat_00 = feat.clone()
            elif mod == 'image':
                # Tensors built from paths start on the CPU; move them next to the parameters.
                X = torch.stack([self._load_image(img_path) for img_path in X]).to(device)
                feat = F.relu(extractor(X))
                mask = torch.ones(feat.shape[0], dtype=torch.int32)
                feat_00 = feat.clone()
            elif mod == 'text':
                X = [f"[CLS] {s}".split() if bool(s) else "[CLS] None".split() for s in X]
                feat, lens = extractor(X)
                feat = F.relu(feat)
                mask = torch.from_numpy(np.array(lens))
                # The leading "[CLS]" token summarises the sentence.
                feat_00 = feat[:, 0].clone()
            else:
                raise ValueError(f"Unknown modality type: {mod}")

            mask = length_to_mask(mask).unsqueeze(1).to(feat.device).int()
            # Cast to bool so it can be used for boolean indexing whatever dtype the loader produced.
            mask_ = observed_mod_indicator[:, [X_idx]].bool()
            mask2 = mask_ * mask_.permute(1, 0)

            feats.append(feat)
            hidden00.append(feat_00)
            mask_mats.append(mask)
            mask_mats_.append(mask_)
            mask2_mats.append(mask2)

        # ---- modality-adaptive similarity between samples ----
        sim_mats = []
        diffs = []
        for i, h in enumerate(hidden00):
            projected = self.simiProj[i](h)
            km1 = guassian_kernel(self.bn(F.relu(projected)), kernel_mul=2.0, kernel_num=3)
            km2 = guassian_kernel(self.bn(h), kernel_mul=2.0, kernel_num=3)
            gate = torch.sigmoid(self.eps[i])
            sim = ((1 - gate) * km1 + gate) * km2
            # Only text similarities are masked here; every modality is masked again below
            # when the score is multiplied by `mask2` before the graph convolutions.
            if self.modalities[i] == "text":
                sim = sim * mask2_mats[i]
            sim_mats.append(sim)
            diffs.append(torch.abs(torch.norm(projected, dim=1) - torch.norm(h, dim=1)))
        sum_of_diff = torch.stack(diffs, dim=1).sum(dim=1)

        sim_sum = torch.stack(sim_mats, dim=0).sum(dim=0)
        mask_sum = torch.stack(mask2_mats, dim=0).sum(dim=0)
        # Pairs with no jointly observed modality have mask_sum == 0; clamping avoids inf/NaN there.
        # Those entries are zeroed by `mask2` below, so valid pairs are unaffected.
        similar_score = sim_sum / mask_sum.clamp(min=1)

        # Keep only similarities above the learned threshold (kept entries retain their value).
        th = torch.sigmoid(self.threshold)[0]
        similar_score = F.relu(similar_score - th)
        bin_mask = similar_score > 0
        similar_score = similar_score + bin_mask * th.detach()

        # ---- propagate information between similar samples ----
        final_h = []
        gs = []
        for i, (h, mask2) in enumerate(zip(hidden00, mask2_mats)):
            adjacency = similar_score * mask2
            g = F.relu(self.GCN1[i](adjacency, h))
            g = F.relu(self.GCN2[i](adjacency, g))
            gs.append(g)
            # Normalised gate between the aggregated (g) and the own (h) representation.
            w1 = torch.sigmoid(self.weight1[i](g))
            w2 = torch.sigmoid(self.weight2[i](h))
            w1 = w1 / (w1 + w2)
            w2 = 1 - w1
            final_h.append(w1 * g + w2 * h)

        # ---- build the token sequence fed to the multimodal transformer ----
        embs = []
        first_token = []  # position of each modality's first token inside the concatenated sequence
        offset = 0
        for X_idx, (h, mod, final, g, mask, mask_) in enumerate(
                zip(feats, self.modalities, final_h, gs, mask_mats, mask_mats_)):
            observed = mask_[:, 0]
            missing = torch.logical_not(observed)
            h_ = h.clone()
            if mod == "text":
                # Only the first token is replaced; the remaining tokens are kept.
                h_[observed, 0] = final[observed]
                h_[missing, 0] = g[missing]
                emb = self.PositionalEncoding(h_)
            else:
                h_[observed] = final[observed]
                h_[missing] = g[missing]
                # Use the imputed representation (the raw `h` would discard the graph step).
                emb = h_.unsqueeze(1)
            token_type_ids = torch.full_like(mask[:, 0, :], X_idx, dtype=torch.long)
            emb = emb + self.token_type_embeddings(token_type_ids)
            first_token.append(offset)
            offset += emb.shape[1]
            embs.append(emb)

        z0 = torch.cat(embs, dim=1)
        z0_mask = torch.cat(mask_mats, dim=-1).int()
        z1 = F.relu(self.MM_model1(z0, z0_mask))
        z2 = F.relu(self.MM_model2(z1, z0_mask))

        # Text modalities span several tokens, so each modality is read at its own start offset.
        combined_hidden = torch.cat([z2[:, start] for start in first_token], dim=-1)
        last_hs_proj = self.dropout(F.relu(self.proj1(combined_hidden)))
        output = self.out_layer(last_hs_proj)
        output = output.squeeze(dim=1)
        return output, sum_of_diff