Missing Modalities in Multimodal healthcare data (M3Care). [1][2][3]
M3Care is a multimodal classification framework that handles missing modalities by imputing latent
task-relevant information using similar samples, based on a modality-adaptive similarity metric.
It supports heterogeneous input types (e.g., tabular, text, vision).
This class provides training, validation, testing, and prediction logic compatible with the
Lightning AI Trainer.
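The imputation idea described above can be illustrated with a toy sketch. This is not the M3Care or iMML implementation: the helper name `impute_missing_embedding`, the choice of cosine similarity, and the plain Python lists standing in for learned embeddings are all assumptions made for illustration. A sample that lacks one modality borrows a similarity-weighted average of that modality's embeddings from samples that have it, with similarity measured on a modality both samples share.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def impute_missing_embedding(query_shared, donors):
    """Hypothetical helper: estimate a missing-modality embedding.

    query_shared: embedding of the modality the query sample does have.
    donors: list of (shared_embedding, target_embedding) pairs taken from
            samples in which both modalities are observed.
    """
    # Weight each donor by its non-negative similarity to the query.
    weights = [max(cosine(query_shared, shared), 0.0) for shared, _ in donors]
    total = sum(weights)
    if total == 0.0:
        # No similar donor at all: fall back to an unweighted mean.
        weights, total = [1.0] * len(donors), float(len(donors))
    dim = len(donors[0][1])
    return [
        sum(w * target[i] for w, (_, target) in zip(weights, donors)) / total
        for i in range(dim)
    ]


# Two donors; only the first resembles the query on the shared modality.
donors = [([1.0, 0.0], [2.0, 2.0]), ([0.0, 1.0], [10.0, 10.0])]
imputed = impute_missing_embedding([1.0, 0.0], donors)
```

In this toy case the second donor has zero similarity to the query, so the imputed vector is taken entirely from the first donor.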
Parameters:
input_dim (list of int, default=None) -- A list specifying the input dimensions for each tabular modality.
embed_size (int, default=128) -- Size of the shared embedding space where modalities are projected.
modalities (list of str, default=None) -- Names of the modalities. Options are "tabular", "text" and "image".
vocab (list, default=None) -- A list containing the path to the corpus file, freq_cutoff (words that occur fewer than
freq_cutoff times are dropped), and the maximum number of words in the vocabulary. To pass your own Vocab object,
use a one-element list [Vocab]. If None, ["train.de-en.en", 50000, 2] is used (if applicable). [2]
learning_rate (float, default=1e-4) -- Learning rate for the optimizer.
weight_decay (float, default=1e-4) -- Weight decay used by the optimizer.
output_dim (int, default=1) -- Number of output dimensions. Typically 1 for binary classification.
loss_fn (callable, default=None) -- Loss function. If None, defaults to nn.BCEWithLogitsLoss().
keep_prob (float, default=0.5) -- Dropout keep probability used in MLP layers.
extractors (list of nn.Module, default=None) -- List of custom feature extractors for each modality. If None, defaults will be used.
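To make the frequency cutoff and maximum-size settings of vocab concrete, the following is a minimal sketch of how a vocabulary of this kind is commonly built. `build_vocab` is a hypothetical helper and not the library's Vocab class; the alphabetical tie-breaking is an assumption added only to keep the result deterministic.

```python
from collections import Counter


def build_vocab(sentences, freq_cutoff, max_size):
    """Hypothetical helper: keep frequent words, capped at max_size entries."""
    counts = Counter(word for sentence in sentences for word in sentence.split())
    # Drop words that occur fewer than freq_cutoff times.
    kept = [(word, count) for word, count in counts.items() if count >= freq_cutoff]
    # Most frequent first; ties broken alphabetically (an assumption).
    kept.sort(key=lambda item: (-item[1], item[0]))
    return [word for word, _ in kept[:max_size]]


corpus = [
    "the patient is stable",
    "the patient is improving",
    "the scan is clear",
]
vocab_words = build_vocab(corpus, freq_cutoff=2, max_size=2)
```

Here "patient" passes the cutoff but is removed by the size cap, while words seen only once never reach the cap at all.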
MUSE is a multimodal representation learning framework designed to handle missing modalities and partially
labeled data. It uses a bipartite graph between samples and modalities to support arbitrary missingness patterns
and a mutual-consistent contrastive loss to encourage the learning of label-discriminative, modality-consistent
features.
This class provides training, validation, testing, and prediction logic compatible with the
Lightning AI Trainer.
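The sample-modality bipartite graph mentioned above can be pictured with a small sketch. It is an illustration only, not the MUSE or iMML code: `build_bipartite_edges` and the dictionary-based mask are hypothetical. Each sample node is linked to a modality node only when that modality is observed, so any missingness pattern simply yields a different edge set.

```python
def build_bipartite_edges(observed):
    """Hypothetical helper: list (sample, modality) edges for observed entries.

    observed: one dict per sample mapping modality name -> bool
              (True if the modality is present for that sample).
    """
    return [
        (sample_idx, modality)
        for sample_idx, mask in enumerate(observed)
        for modality, present in mask.items()
        if present
    ]


observed = [
    {"tabular": True, "text": True},
    {"tabular": True, "text": False},  # text is missing for sample 1
    {"tabular": False, "text": True},  # tabular is missing for sample 2
]
edges = build_bipartite_edges(observed)
```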
Parameters:
input_dim (list of int, default=None) -- A list specifying the input dimensions for each tabular/series modality.
RAGPT is designed for incomplete vision-language learning, where one modality may be missing at
inference or training time. It combines three core modules to address this challenge: 1) Multi-Channel Retriever,
which retrieves semantically similar instances from a training database, per modality; 2) Missing Modality
Generator, which fills in missing modality data using context from retrieved neighbors; and 3) Context-Aware
Prompter, which generates dynamic prompts based on context to improve downstream classification in
multimodal transformers.
This class provides training, validation, testing, and prediction logic compatible with the
Lightning AI Trainer.
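As a rough picture of per-modality retrieval, the sketch below keeps one embedding bank per modality and looks up the nearest stored instances only in the channels where the query has data. The helper `retrieve_top_k`, the toy two-dimensional embeddings, and the use of cosine similarity are assumptions for illustration and do not reflect the internals of MCR or RAGPT.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def retrieve_top_k(query, bank, k):
    """Hypothetical helper: indices of the k bank entries most similar to query."""
    ranked = sorted(
        range(len(bank)),
        key=lambda i: cosine(query, bank[i]),
        reverse=True,
    )
    return ranked[:k]


# One retrieval channel per modality, each with its own embedding bank.
banks = {
    "image": [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]],
    "text": [[0.0, 1.0], [1.0, 0.0], [0.1, 0.9]],
}
query = {"image": [1.0, 0.0], "text": None}  # text is missing for this sample

# Retrieve neighbors only through the modalities the query actually has.
neighbors = {
    name: retrieve_top_k(vector, banks[name], k=2)
    for name, vector in query.items()
    if vector is not None
}
```

The retrieved indices could then be used to look up the neighbors' text entries as context for filling in the missing modality.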
Parameters:
vilt (transformers.ViltModel, optional) -- Pretrained model used for joint vision-language encoding. If None, defaults to
ViltModel.from_pretrained('dandelin/vilt-b32-mlm').
max_text_len (int, default=128) -- Maximum number of tokens for text inputs.
max_image_len (int, default=145) -- Maximum number of image patches/tokens processed by the vision encoder.
prompt_position (int, default=0) -- Index position at which to insert dynamic prompts in the transformer input sequence.
prompt_length (int, default=1) -- Number of prompt tokens to insert for dynamic prompt tuning.
cls_num (int, default=2) -- Number of target classes for classification tasks.
loss (callable, optional) -- Loss function. If None, defaults to F.cross_entropy.
learning_rate (float, default=1e-3) -- Learning rate for the optimizer.
weight_decay (float, default=2e-2) -- Weight decay used by the optimizer.
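The interaction of prompt_position and prompt_length can be sketched on a plain token list. This is a simplified illustration under the parameter descriptions above, not the model's actual tensor code: `insert_prompts` and the placeholder token strings are hypothetical.

```python
def insert_prompts(tokens, prompts, position):
    """Hypothetical helper: splice prompt tokens into a sequence at `position`."""
    return tokens[:position] + prompts + tokens[position:]


sequence = ["[CLS]", "a", "chest", "x-ray"]
prompts = ["<p0>"]  # corresponds to prompt_length = 1
with_prompts = insert_prompts(sequence, prompts, position=0)
```

With the defaults (position 0, length 1) the sequence grows by one token at the front; a larger prompt_length would lengthen the input by that many tokens.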
References
Example
>>> from imml.retrieve import MCR
>>> from imml.load import RAGPTDataset, RAGPTCollator
>>> from imml.classify import RAGPT
>>> from lightning import Trainer
>>> from torch.utils.data import DataLoader
>>> images = ["docs/figures/graph.png", "docs/figures/logo_imml.png", "docs/figures/graph.png", "docs/figures/logo_imml.png"]
>>> texts = ["This is the graphical abstract of iMML.", "This is the logo of iMML.", "This is the graphical abstract of iMML.", "This is the logo of iMML."]
>>> Xs = [images, texts]
>>> y = [0, 1, 0, 1]
>>> modalities = ["image", "text"]
>>> estimator = MCR(modalities=modalities)
>>> database = estimator.fit_transform(Xs=Xs, y=y)
>>> train_data = RAGPTDataset(database=database)
>>> train_dataloader = DataLoader(train_data, collate_fn=RAGPTCollator)
>>> trainer = Trainer(max_epochs=2, logger=False, enable_checkpointing=False)
>>> estimator = RAGPT()
>>> trainer.fit(estimator, train_dataloader)
>>> trainer.predict(estimator, train_dataloader)