Missing Modalities in Multimodal healthcare data (M3Care). [1][2][3]
M3Care is a multimodal classification framework that handles missing modalities by imputing latent
task-relevant information using similar samples, based on a modality-adaptive similarity metric.
It supports heterogeneous input types (e.g., tabular, text, vision).
This class provides training, validation, testing, and prediction logic compatible with the
Lightning Trainer.
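The idea of imputing a missing latent vector from similar samples can be sketched in a few lines. This is a minimal illustration in plain Python, not the M3Care implementation: the function names, the use of cosine similarity, and the softmax weighting with a `temperature` are all assumptions chosen for clarity, whereas the model itself uses a learned, modality-adaptive similarity metric.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0

def impute_from_neighbors(query_obs, donors_obs, donors_missing, temperature=1.0):
    """Fill a missing latent vector with a softmax-weighted average of donor vectors.

    query_obs:      latent vector of the modality the query sample does have
    donors_obs:     the same modality for samples that have both modalities
    donors_missing: the modality the query lacks, for those same donor samples
    (All names here are hypothetical and used only for illustration.)
    """
    sims = [cosine(query_obs, d) / temperature for d in donors_obs]
    top = max(sims)
    exps = [math.exp(s - top) for s in sims]  # subtract the max for numerical stability
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(donors_missing[0])
    imputed = [sum(w * d[j] for w, d in zip(weights, donors_missing)) for j in range(dim)]
    return imputed, weights

# Toy latents: the query resembles donor 0 much more than donor 1.
query = [1.0, 0.0]
donors_obs = [[0.9, 0.1], [0.0, 1.0]]
donors_missing = [[2.0, 2.0], [-2.0, -2.0]]
imputed, weights = impute_from_neighbors(query, donors_obs, donors_missing, temperature=0.1)
```

With a low temperature, almost all of the weight goes to the most similar donor, so the imputed vector stays close to that donor's latent.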
Parameters:
input_dim (list of int, default=None) -- A list specifying the input dimensions for each tabular modality.
embed_size (int, default=128) -- Size of the shared embedding space where modalities are projected.
modalities (list of str, default=None) -- Names of the modalities. Options are "tabular", "text" and "image".
vocab (list, default=None) -- A three-element list: path to the corpus file, maximum number of words in the
vocabulary, and freq_cutoff (words that occur fewer than freq_cutoff times are dropped). To pass your own
Vocab object, use a one-element list [Vocab]. If None, ["test.de-en.en", 50000, 2] is used
(if applicable). [2]
learning_rate (float, default=1e-4) -- Learning rate for the optimizer.
weight_decay (float, default=1e-4) -- Weight decay used by the optimizer.
output_dim (int, default=1) -- Number of classes in your response variable. Typically 1 for binary classification.
loss_fn (callable, default=None) -- Loss function. If None, defaults to nn.BCEWithLogitsLoss() if output_dim <= 2, else nn.CrossEntropyLoss().
keep_prob (float, default=0.5) -- Dropout keep probability used in MLP layers.
extractors (list of nn.Module, default=None) -- List of custom feature extractors for each modality. If None, defaults will be used.
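The effect of the `vocab` parameter's maximum size and `freq_cutoff` can be illustrated with a small stand-alone sketch. It is not the library's Vocab class: `build_vocab` and the `<pad>`/`<unk>` special tokens are hypothetical, and whitespace tokenization is assumed only to keep the example short.

```python
from collections import Counter

def build_vocab(sentences, max_size, freq_cutoff):
    """Map words to integer ids, keeping only sufficiently frequent words.

    Hypothetical helper: words seen fewer than freq_cutoff times are dropped,
    and at most max_size of the remaining words are kept (most frequent first).
    """
    counts = Counter(word for sentence in sentences for word in sentence.split())
    kept = [w for w, c in counts.most_common() if c >= freq_cutoff][:max_size]
    vocab = {"<pad>": 0, "<unk>": 1}  # assumed special tokens
    for word in kept:
        vocab[word] = len(vocab)
    return vocab

corpus = ["the patient is stable", "the patient is improving", "the scan is clear"]
vocab = build_vocab(corpus, max_size=50000, freq_cutoff=2)
```

Here only "the", "is" and "patient" occur at least twice, so every other word would map to the unknown token at lookup time.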
References
See also
M3CareDataset
Example
>>> from lightning import Trainer
>>> import numpy as np
>>> import pandas as pd
>>> from torch.utils.data import DataLoader
>>> from imml.classify import M3Care
>>> from imml.load import M3CareDataset
>>> from imml.ampute import Amputer
>>> Xs = [pd.DataFrame(np.random.default_rng(42).random((2, 10)))]
>>> Xs.append(pd.DataFrame(np.random.default_rng(42).random((2, 15))))
>>> Xs.append(pd.DataFrame(["docs/figures/graph.png", "docs/figures/logo_imml.png"]))
>>> Xs.append(pd.DataFrame(["This is the graphical abstract of iMML.", "This is the logo of iMML."]))
>>> Xs = Amputer(p=0.2, random_state=42).fit_transform(Xs)  # this step is optional
>>> y = pd.Series(np.random.default_rng(42).integers(0, 2, len(Xs[0])), dtype=np.float32)
>>> train_data = M3CareDataset(Xs=Xs, y=y)
>>> train_dataloader = DataLoader(dataset=train_data, batch_size=10, shuffle=True)
>>> trainer = Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
>>> modalities = ["tabular", "tabular", "image", "text"]
>>> estimator = M3Care(modalities=modalities, input_dim=[X.shape[1] for X, mod in zip(Xs, modalities) if mod == "tabular"])
>>> trainer.fit(estimator, train_dataloader)
>>> trainer.predict(estimator, train_dataloader)
MUSE is a multimodal representation learning framework designed to handle missing modalities and partially
labeled data. It uses a bipartite graph between samples and modalities to support arbitrary missingness patterns
and a mutual-consistent contrastive loss to encourage the learning of label-discriminative, modality-consistent
features.
This class provides training, validation, testing, and prediction logic compatible with the
Lightning Trainer.
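A bipartite sample-modality graph can represent any missingness pattern because an edge exists only where a modality is actually observed. The sketch below builds such an edge list from a boolean mask; it is an illustration under assumptions (the `bipartite_edges` helper and the mask layout are hypothetical), not the graph construction used inside MUSE.

```python
def bipartite_edges(observed):
    """Return (sample, modality) edges for every observed entry.

    observed[i][m] is True when sample i has modality m.
    """
    return [
        (i, m)
        for i, row in enumerate(observed)
        for m, present in enumerate(row)
        if present
    ]

# Three samples, modalities ordered as ["tabular", "series", "text"].
observed = [
    [True, True, True],    # complete sample
    [True, False, True],   # series missing
    [False, False, True],  # only text observed
]
edges = bipartite_edges(observed)

# Degree of each modality node: how many samples are connected to it.
degree = [sum(1 for _, m in edges if m == k) for k in range(3)]
```

Missing entries simply produce no edge, so no placeholder values are needed for absent modalities.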
Parameters:
input_dim (list of int, default=None) -- A list specifying the input dimensions for each tabular/series modality.
modalities (list of str, default=None) -- Names of the modalities. Options are "tabular", "text" and "series".
tokenizer (str, default=None) -- Tokenizer to use for text modality. If None, defaults to "emilyalsentzer/Bio_ClinicalBERT" tokenizer.
learning_rate (float, default=2e-4) -- Learning rate for the optimizer.
weight_decay (float, default=0) -- Weight decay used by the optimizer.
output_dim (int, default=1) -- Number of classes in your response variable. Typically 1 for binary classification.
extractors (list of nn.Module, default=None) -- List of custom feature extractors for each modality. If None, defaults will be used.
gnn_layers (int, default=2) -- Number of GNN layers used to propagate sample-modality representations.
gnn_norm (str or None, default=None) -- Optional normalization strategy in GNN layers (e.g., 'batchnorm', 'layernorm').
loss_fn (callable, default=None) -- Loss function. If None, defaults to nn.BCEWithLogitsLoss() if output_dim <= 2, else nn.CrossEntropyLoss().
bert_type (str, default="prajjwal1/bert-tiny") -- HuggingFace model name or path for BERT backbone used in the text encoder.
dropout (float, default=0.25) -- Dropout rate applied in the encoders and classifier head.
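Because the default binary loss is nn.BCEWithLogitsLoss(), the model is expected to output raw logits rather than probabilities. The following plain-Python sketch shows the standard numerically stable form of binary cross-entropy on a logit and checks it against the naive sigmoid-then-log version. It illustrates the formula only and is not the PyTorch implementation; both function names are hypothetical.

```python
import math

def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy computed directly from a raw logit."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

def naive_bce(logit, target):
    """Reference version: apply the sigmoid first, then take logs (can overflow)."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))

loss_confident_right = bce_with_logits(4.0, 1.0)  # small loss
loss_confident_wrong = bce_with_logits(4.0, 0.0)  # large loss
loss_extreme = bce_with_logits(1000.0, 0.0)       # stays finite, no overflow
```

A custom `loss_fn` that expects probabilities would therefore need its own sigmoid, since none is applied before the default loss.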
References
See also
MUSEDataset
Example
>>> from lightning import Trainer
>>> import numpy as np
>>> import pandas as pd
>>> from torch.utils.data import DataLoader
>>> from imml.classify import MUSE
>>> from imml.load import MUSEDataset
>>> from imml.ampute import Amputer
>>> Xs = [pd.DataFrame(np.random.default_rng(42).random((2, 10)))]
>>> Xs.append(pd.DataFrame(np.random.default_rng(42).random((2, 15))))
>>> Xs.append(pd.DataFrame(["This is the graphical abstract of iMML.", "This is the logo of iMML."]))
>>> Xs = Amputer(p=0.2, random_state=42).fit_transform(Xs)  # this step is optional
>>> y = pd.Series(np.random.default_rng(42).integers(0, 2, len(Xs[0])), dtype=np.float32)
>>> train_data = MUSEDataset(Xs=Xs, y=y)
>>> train_dataloader = DataLoader(dataset=train_data, batch_size=10, shuffle=True)
>>> trainer = Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
>>> modalities = ["tabular", "tabular", "text"]
>>> estimator = MUSE(modalities=modalities, input_dim=[Xs[0].shape[1], Xs[1].shape[1]])
>>> trainer.fit(estimator, train_dataloader)
>>> trainer.predict(estimator, train_dataloader)
RAGPT is designed for incomplete vision-language learning, where one modality may be missing at
inference or training time. It combines three core modules to address this challenge: 1) Multi-Channel Retriever,
which retrieves semantically similar instances from a training database, per modality; 2) Missing Modality
Generator, which fills in missing modality data using context from retrieved neighbors; and 3) Context-Aware
Prompter, which generates dynamic prompts based on context to improve downstream classification in
multimodal transformers.
This class provides training, validation, testing, and prediction logic compatible with the
Lightning Trainer.
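The retrieve-then-fill idea behind the first two modules can be sketched with toy embeddings. This is a simplified illustration, not RAGPT's retriever or generator: the embedding banks, the helper names, cosine ranking, and the plain mean over neighbors are all assumptions, whereas the actual generator is a learned module.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0

def retrieve_top_k(query, memory, k):
    """Return indices of the k memory vectors most similar to the query."""
    ranked = sorted(range(len(memory)), key=lambda i: cosine(query, memory[i]), reverse=True)
    return ranked[:k]

def mean_vector(vectors):
    """Element-wise mean of a list of equal-length vectors."""
    dim = len(vectors[0])
    return [sum(v[j] for v in vectors) / len(vectors) for j in range(dim)]

# Hypothetical database: text embeddings paired with image embeddings.
text_bank = [[1.0, 0.0], [0.8, 0.2], [0.0, 1.0]]
image_bank = [[4.0, 0.0], [2.0, 0.0], [0.0, 9.0]]

query_text = [1.0, 0.1]  # a sample whose image is missing
neighbors = retrieve_top_k(query_text, text_bank, k=2)
filled_image = mean_vector([image_bank[i] for i in neighbors])
```

Retrieval uses the modality that is present (text here), and the missing modality is filled from the paired entries of the retrieved neighbors.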
Parameters:
vilt (transformers.ViltModel, optional) -- Pretrained model used for joint vision-language encoding. If None, defaults to
ViltModel.from_pretrained('dandelin/vilt-b32-mlm').
max_text_len (int, default=40) -- Maximum token length for text inputs (used during prompt generation). Must not exceed the
max_position_embeddings of the ViLT model (default: 40 for 'dandelin/vilt-b32-mlm').
max_image_len (int, default=145) -- Maximum number of image patches/tokens processed by the vision encoder.
prompt_position (int, default=0) -- Index position at which to insert dynamic prompts in the transformer input sequence.
prompt_length (int, default=1) -- Number of prompt tokens to insert for dynamic prompt tuning.
output_dim (int, default=2) -- Number of classes in your response variable. Typically 2 for binary classification.
loss_fn (callable, default=None) -- Loss function. If None, defaults to nn.BCEWithLogitsLoss() if output_dim <= 2, else nn.CrossEntropyLoss().
learning_rate (float, default=1e-3) -- Learning rate for the optimizer.
weight_decay (float, default=2e-2) -- Weight decay used by the optimizer.
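One way to read `prompt_position` and `prompt_length` is as a splice into the input sequence. The sketch below uses string tokens and a hypothetical `insert_prompts` helper to show that interpretation. It is an assumption about the semantics, since the model operates on embedding tensors and its exact insertion logic is not described here.

```python
def insert_prompts(tokens, prompts, prompt_position=0):
    """Return a new sequence with the prompt tokens spliced in at prompt_position."""
    return tokens[:prompt_position] + prompts + tokens[prompt_position:]

sequence = ["[CLS]", "chest", "x-ray", "[SEP]"]
prompts = ["<p0>"]  # prompt_length = 1

at_start = insert_prompts(sequence, prompts, prompt_position=0)
after_cls = insert_prompts(sequence, prompts, prompt_position=1)
```

Under this reading, the sequence grows by `prompt_length` tokens, which is worth keeping in mind alongside the `max_text_len` limit.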
References
See also
RAGPTDataset, MCR
Example
>>> from imml.retrieve import MCR
>>> from imml.load import RAGPTDataset, RAGPTCollator
>>> from imml.classify import RAGPT
>>> from lightning import Trainer
>>> from torch.utils.data import DataLoader
>>> images = ["docs/figures/graph.png", "docs/figures/logo_imml.png", "docs/figures/graph.png", "docs/figures/logo_imml.png"]
>>> texts = ["This is the graphical abstract of iMML.", "This is the logo of iMML.", "This is the graphical abstract of iMML.", "This is the logo of iMML."]
>>> Xs = [images, texts]
>>> y = [0, 1, 0, 1]
>>> modalities = ["image", "text"]
>>> estimator = MCR(modalities=modalities)
>>> database = estimator.fit_transform(Xs=Xs, y=y)
>>> train_data = RAGPTDataset(database=database)
>>> train_dataloader = DataLoader(train_data, collate_fn=RAGPTCollator)
>>> trainer = Trainer(max_epochs=2, logger=False, enable_checkpointing=False)
>>> estimator = RAGPT()
>>> trainer.fit(estimator, train_dataloader)
>>> trainer.predict(estimator, train_dataloader)