Classify

Missing Modalities in Multimodal healthcare data (M3Care)

class imml.classify.M3Care(input_dim: List = None, hidden_dim: int = 128, embed_size: int = 128, modalities: List = None, vocab: list = None, learning_rate: float = 0.0001, weight_decay: float = 0.0001, output_dim: int = 1, loss_fn: callable = None, keep_prob: float = 0.5, extractors: List = None)

Bases: LightningModule

Missing Modalities in Multimodal healthcare data (M3Care). [1] [2] [3]

M3Care is a multimodal classification framework that handles missing modalities by imputing their task-relevant information in the latent space from similar samples, which are identified with a modality-adaptive similarity metric. It supports heterogeneous input types (tabular, text and image).
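
As a rough illustration of similarity-based latent imputation, the sketch below fills a sample's missing modality embedding with a softmax-weighted average of the same modality's embeddings from other samples. It is a simplification under stated assumptions, not the M3Care implementation in imml: the learned, modality-adaptive similarity is replaced by plain cosine similarity averaged over the modalities both samples observe, and `impute_missing_embeddings` and its `temperature` parameter are hypothetical.

```python
import numpy as np


def impute_missing_embeddings(embs, observed, temperature=1.0):
    """Simplified similarity-based imputation of missing modality embeddings.

    embs: float array (n_mods, n_samples, dim); rows of missing modalities must be zeros.
    observed: bool array (n_samples, n_mods), True where the modality is present.
    """
    n_mods = embs.shape[0]
    # Per-modality cosine similarity between every pair of samples.
    unit = embs / (np.linalg.norm(embs, axis=-1, keepdims=True) + 1e-8)
    sims = np.einsum("mid,mjd->mij", unit, unit)
    # A pair of samples is only compared on modalities that both of them observe.
    shared = observed.T[:, :, None] & observed.T[:, None, :]
    counts = shared.sum(axis=0)
    sim = np.where(counts > 0, (sims * shared).sum(axis=0) / np.maximum(counts, 1), -np.inf)
    np.fill_diagonal(sim, -np.inf)  # a sample cannot be its own donor
    imputed = embs.copy()
    for m in range(n_mods):
        for i in np.flatnonzero(~observed[:, m]):
            # Only samples that observe modality m can donate information.
            scores = np.where(observed[:, m], sim[i], -np.inf)
            if not np.isfinite(scores).any():
                continue  # no comparable donor: leave the zeros in place
            weights = np.exp((scores - scores[np.isfinite(scores)].max()) / temperature)
            weights /= weights.sum()  # softmax over donors
            imputed[m, i] = weights @ embs[m]
    return imputed
```

Lowering `temperature` concentrates the weights on the most similar donors; a sample with no comparable donor keeps its zero embedding.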

This class provides training, validation, testing, and prediction logic compatible with the Lightning AI Trainer.

Parameters:
  • input_dim (list of int, default=None) -- A list specifying the input dimensions for each tabular modality.

  • hidden_dim (int, default=128) -- Hidden dimension size.

  • embed_size (int, default=128) -- Size of the shared embedding space where modalities are projected.

  • modalities (list of str, default=None) -- Names of the modalities. Options are "tabular", "text" and "image".

  • vocab (list, default=None) -- List containing the path to the corpus file, the maximum number of words in the vocabulary, and freq_cutoff (words occurring fewer than freq_cutoff times are dropped). To pass your own Vocab object, use a single-element list [Vocab]. If None, ["train.de-en.en", 50000, 2] is used (if applicable). [2]

  • learning_rate (float, default=1e-4) -- Learning rate for the optimizer.

  • weight_decay (float, default=1e-4) -- Weight decay used by the optimizer.

  • output_dim (int, default=1) -- Number of output dimensions. Typically 1 for binary classification.

  • loss_fn (callable, default=None) -- Loss function. If None, defaults to nn.BCEWithLogitsLoss().

  • keep_prob (float, default=0.5) -- Dropout keep probability used in MLP layers.

  • extractors (list of nn.Module, default=None) -- List of custom feature extractors for each modality. If None, defaults will be used.
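
For the `vocab` parameter, the sketch below shows how a vocabulary is typically built from a corpus with a frequency cutoff and a size limit. `build_vocab` is a hypothetical helper, not part of imml, and it assumes whitespace tokenization and no special tokens (such as padding or unknown-word tokens), which real vocabulary objects usually add.

```python
from collections import Counter


def build_vocab(corpus_lines, max_size, freq_cutoff):
    """Hypothetical helper: map the max_size most frequent words to integer ids,
    after dropping words that occur fewer than freq_cutoff times."""
    counts = Counter(word for line in corpus_lines for word in line.split())
    kept = [word for word, count in counts.items() if count >= freq_cutoff]
    # Most frequent first; ties broken alphabetically for a deterministic result.
    kept.sort(key=lambda word: (-counts[word], word))
    return {word: idx for idx, word in enumerate(kept[:max_size])}
```

With a corpus of ["a b a c", "a b d"], `freq_cutoff=2` keeps only "a" and "b", and `max_size=1` then keeps only "a".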

References

Example

>>> from imml.classify import M3Care
>>> from lightning import Trainer
>>> import torch
>>> import numpy as np
>>> import pandas as pd
>>> from torch.utils.data import DataLoader
>>> from imml.impute import get_observed_mod_indicator
>>> from imml.load import M3CareDataset
>>> Xs = [pd.DataFrame(np.random.default_rng(42).random((20, 10))) for i in range(3)]
>>> train_data = M3CareDataset(Xs=[torch.from_numpy(X.values).float() for X in Xs],
...                            y=torch.from_numpy(np.random.default_rng(42).integers(0, 2, len(Xs[0]))).float(),
...                            observed_mod_indicator=torch.from_numpy(get_observed_mod_indicator(Xs).values))
>>> train_dataloader = DataLoader(dataset=train_data, batch_size=10, shuffle=True)
>>> trainer = Trainer(max_epochs=2, logger=False, enable_checkpointing=False)
>>> estimator = M3Care(modalities=["tabular"] * len(Xs), input_dim=[X.shape[1] for X in Xs])
>>> trainer.fit(estimator, train_dataloader)
>>> trainer.predict(estimator, train_dataloader)

training_step(batch, batch_idx=None)

Method required for training using Lightning AI Trainer.

validation_step(batch, batch_idx=None)

Method required for validating using Lightning AI Trainer.

test_step(batch, batch_idx=None)

Method required for testing using Lightning AI Trainer.

predict_step(batch, batch_idx=None)

Method required for predicting using Lightning AI Trainer.

configure_optimizers()

Method required by the Lightning AI Trainer to set up the optimizer used during training.

Mutual-consistent graph contrastive learning (MUSE)

class imml.classify.MUSE(input_dim: List = None, hidden_dim: int = 128, modalities: List = None, tokenizer=None, learning_rate: float = 0.0002, weight_decay: float = 0.0, output_dim: int = 1, extractors: List = None, gnn_layers: int = 2, gnn_norm: str = None, code_pretrained_embedding: bool = True, bert_type: str = 'prajjwal1/bert-tiny', dropout: float = 0.25)

Bases: LightningModule

Mutual-consistent graph contrastive learning (MUSE). [4] [5]

MUSE is a multimodal representation learning framework designed to handle missing modalities and partially labeled data. It uses a bipartite graph between samples and modalities to support arbitrary missingness patterns and a mutual-consistent contrastive loss to encourage the learning of label-discriminative, modality-consistent features.
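
The bipartite formulation is what lets MUSE accept arbitrary missingness patterns: every observed (sample, modality) pair becomes an edge, and a missing modality simply contributes no edge. The sketch below illustrates this with a single mean-aggregation step in NumPy. The helper names are hypothetical, and the learned GNN layers (controlled by `gnn_layers` in imml) are replaced by a plain average, so this is not MUSE's actual message passing.

```python
import numpy as np


def bipartite_edges(observed):
    """(sample, modality) edges of the bipartite graph: one edge per observed entry."""
    rows, cols = np.nonzero(observed)
    return list(zip(rows.tolist(), cols.tolist()))


def aggregate_to_samples(mod_embs, observed):
    """One mean-aggregation step from edges to sample nodes.

    mod_embs: (n_mods, n_samples, dim) modality-specific embeddings, NaN where missing.
    observed: bool (n_samples, n_mods). Missing modalities have no edge and are skipped.
    """
    mask = observed.T[:, :, None].astype(float)  # (n_mods, n_samples, 1)
    summed = (np.nan_to_num(mod_embs) * mask).sum(axis=0)  # (n_samples, dim)
    n_edges = np.maximum(mask.sum(axis=0), 1.0)  # (n_samples, 1); avoids division by zero
    return summed / n_edges
```

A sample observed in only one modality is represented by that modality alone, without any explicit imputation step.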

This class provides training, validation, testing, and prediction logic compatible with the Lightning AI Trainer.

Parameters:
  • input_dim (list of int, default=None) -- A list specifying the input dimensions for each tabular/series modality.

  • hidden_dim (int, default=128) -- Hidden dimension size.

  • modalities (list of str, default=None) -- Names of the modalities. Options are "tabular", "text" and "series".

  • tokenizer (str, default=None) -- Tokenizer to use for text modality. If None, defaults to "emilyalsentzer/Bio_ClinicalBERT" tokenizer.

  • learning_rate (float, default=2e-4) -- Learning rate for the optimizer.

  • weight_decay (float, default=0) -- Weight decay used by the optimizer.

  • output_dim (int, default=1) -- Number of output dimensions. Typically 1 for binary classification.

  • extractors (list of nn.Module, default=None) -- List of custom feature extractors for each modality. If None, defaults will be used.

  • gnn_layers (int, default=2) -- Number of GNN layers used to propagate sample-modality representations.

  • gnn_norm (str or None, default=None) -- Optional normalization strategy in GNN layers (e.g., 'batchnorm', 'layernorm').

  • code_pretrained_embedding (bool, default=True) -- If True, initializes pretrained embeddings for text/code features.

  • bert_type (str, default="prajjwal1/bert-tiny") -- HuggingFace model name or path for BERT backbone used in the text encoder.

  • dropout (float, default=0.25) -- Dropout rate applied in the encoders and classifier head.

References

Example

>>> from imml.classify import MUSE
>>> from lightning import Trainer
>>> import torch
>>> import numpy as np
>>> import pandas as pd
>>> from torch.utils.data import DataLoader
>>> from imml.load import MUSEDataset
>>> from imml.impute import get_observed_mod_indicator
>>> Xs = [pd.DataFrame(np.random.default_rng(42).random((20, 10))) for i in range(3)]
>>> train_data = MUSEDataset(Xs=[torch.from_numpy(X.values).float() for X in Xs],
...                          observed_mod_indicator=torch.from_numpy(get_observed_mod_indicator(Xs).values),
...                          y=torch.from_numpy(np.random.default_rng(42).integers(0, 2, len(Xs[0]))).float(),
...                          y_indicator=torch.ones(len(Xs[0])).bool())
>>> train_dataloader = DataLoader(dataset=train_data, batch_size=10, shuffle=True)
>>> trainer = Trainer(max_epochs=2, logger=False, enable_checkpointing=False)
>>> estimator = MUSE(modalities=["tabular"] * len(Xs), input_dim=[X.shape[1] for X in Xs])
>>> trainer.fit(estimator, train_dataloader)
>>> trainer.predict(estimator, train_dataloader)
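
In the example above every sample is labeled (`y_indicator` is all True). For partially labeled data, a common way to use such an indicator is to exclude unlabeled samples from the supervised loss, as in the sketch below. This is an assumption about how the indicator is used, not code taken from imml; `masked_bce_with_logits` is a hypothetical helper, and binary cross-entropy on logits is chosen only because the default `output_dim=1` suggests binary classification.

```python
import numpy as np


def masked_bce_with_logits(logits, y, y_indicator):
    """Binary cross-entropy on logits, averaged over labeled samples only."""
    # Numerically stable form: max(x, 0) - x * y + log(1 + exp(-|x|)).
    losses = np.maximum(logits, 0) - logits * y + np.log1p(np.exp(-np.abs(logits)))
    n_labeled = int(np.sum(y_indicator))
    if n_labeled == 0:
        return 0.0  # nothing labeled in this batch: no supervised signal
    return float(losses[y_indicator].sum() / n_labeled)
```

Unlabeled samples still contribute through the unsupervised parts of a model such as MUSE; only the supervised term ignores them.
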

training_step(batch, batch_idx=None)

Method required for training using Lightning AI Trainer.

validation_step(batch, batch_idx=None)

Method required for validating using Lightning AI Trainer.

test_step(batch, batch_idx=None)

Method required for testing using Lightning AI Trainer.

predict_step(batch, batch_idx=None)

Method required for predicting using Lightning AI Trainer.

configure_optimizers()

Method required by the Lightning AI Trainer to set up the optimizer used during training.

Retrieval-AuGmented dynamic Prompt Tuning (RAGPT)

class imml.classify.RAGPT(vilt: ViltModel = None, max_text_len: int = 128, max_image_len: int = 145, prompt_position: int = 0, prompt_length: int = 1, dropout_rate: float = 0.2, hidden_dim: int = 768, cls_num: int = 2, loss: callable = None, learning_rate: float = 0.001, weight_decay: float = 0.02)

Bases: LightningModule

Retrieval-AuGmented dynamic Prompt Tuning (RAGPT). [6] [7]

RAGPT is designed for incomplete vision-language learning, where one modality may be missing during training or inference. It combines three modules to address this challenge: 1) a Multi-Channel Retriever, which retrieves semantically similar instances from a training database for each modality; 2) a Missing Modality Generator, which fills in missing modality data using the retrieved neighbors as context; and 3) a Context-Aware Prompter, which builds dynamic prompts from the retrieved context to improve downstream classification with multimodal transformers.
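
The retrieve-then-fill idea can be sketched in a few lines of NumPy: neighbors are retrieved through the channel of the modality that is present, and the missing modality is approximated from those neighbors. This is a simplified illustration, not RAGPT's implementation: cosine similarity over precomputed feature vectors stands in for the Multi-Channel Retriever, the learned Missing Modality Generator is replaced by a plain mean, and `retrieve_top_k` and `fill_missing` are hypothetical names.

```python
import numpy as np


def retrieve_top_k(query, database, k):
    """Indices of the k database rows with the highest cosine similarity to query."""
    q = query / (np.linalg.norm(query) + 1e-8)
    db = database / (np.linalg.norm(database, axis=1, keepdims=True) + 1e-8)
    sims = db @ q
    return np.argsort(-sims, kind="stable")[:k]


def fill_missing(query_observed, db_observed, db_missing, k=2):
    """Retrieve neighbors via the observed modality, then approximate the
    missing modality by the mean of the neighbors' features for it."""
    idx = retrieve_top_k(query_observed, db_observed, k)
    return db_missing[idx].mean(axis=0), idx
```

For example, a sample with text but no image would be matched against the text channel of the database, and its image features approximated from the images of the retrieved neighbors.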

This class provides training, validation, testing, and prediction logic compatible with the Lightning AI Trainer.

Parameters:
  • vilt (transformers.ViltModel, optional) -- Pretrained model used for joint vision-language encoding. If None, defaults to ViltModel.from_pretrained('dandelin/vilt-b32-mlm').

  • max_text_len (int, default=128) -- Maximum number of tokens for text inputs.

  • max_image_len (int, default=145) -- Maximum number of image patches/tokens processed by the vision encoder.

  • prompt_position (int, default=0) -- Index position at which to insert dynamic prompts in the transformer input sequence.

  • prompt_length (int, default=1) -- Number of prompt tokens to insert for dynamic prompt tuning.

  • dropout_rate (float, default=0.2) -- Dropout probability.

  • hidden_dim (int, default=768) -- Hidden dimension size.

  • cls_num (int, default=2) -- Number of target classes for classification tasks.

  • loss (callable, optional) -- Loss function. If None, defaults to F.cross_entropy.

  • learning_rate (float, default=1e-3) -- Learning rate for the optimizer.

  • weight_decay (float, default=2e-2) -- Weight decay used by the optimizer.
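
To make `prompt_position` and `prompt_length` concrete, the sketch below inserts `prompt_length` prompt vectors into a token sequence at index `prompt_position` and extends the attention mask to match. It assumes the prompts are vectors of width `hidden_dim` placed directly in the input sequence; how RAGPT's Context-Aware Prompter generates the prompts and where exactly they enter ViLT is not shown, and both helpers are hypothetical.

```python
import numpy as np


def insert_prompts(tokens, prompts, prompt_position):
    """Insert prompt vectors into a token sequence.

    tokens: (seq_len, hidden_dim); prompts: (prompt_length, hidden_dim).
    """
    return np.concatenate([tokens[:prompt_position], prompts, tokens[prompt_position:]], axis=0)


def extend_mask(mask, prompt_length, prompt_position):
    """Mark the inserted prompt positions as attendable (1) in the attention mask."""
    ones = np.ones(prompt_length, dtype=mask.dtype)
    return np.concatenate([mask[:prompt_position], ones, mask[prompt_position:]])
```

With the defaults (`prompt_position=0`, `prompt_length=1`), a single prompt vector would be prepended to the sequence.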

References

Example

>>> from imml.retrieve import MCR
>>> from imml.load import RAGPTDataset, RAGPTCollator
>>> from imml.classify import RAGPT
>>> from lightning import Trainer
>>> from torch.utils.data import DataLoader
>>> images = ["docs/figures/graph.png", "docs/figures/logo_imml.png",
...           "docs/figures/graph.png", "docs/figures/logo_imml.png"]
>>> texts = ["This is the graphical abstract of iMML.", "This is the logo of iMML.",
...          "This is the graphical abstract of iMML.", "This is the logo of iMML."]
>>> Xs = [images, texts]
>>> y = [0, 1, 0, 1]
>>> modalities = ["image", "text"]
>>> retriever = MCR(modalities=modalities)
>>> database = retriever.fit_transform(Xs=Xs, y=y)
>>> train_data = RAGPTDataset(database=database)
>>> train_dataloader = DataLoader(train_data, collate_fn=RAGPTCollator)
>>> trainer = Trainer(max_epochs=2, logger=False, enable_checkpointing=False)
>>> estimator = RAGPT()
>>> trainer.fit(estimator, train_dataloader)
>>> trainer.predict(estimator, train_dataloader)

training_step(batch, batch_idx=None)

Method required for training using Lightning AI Trainer.

validation_step(batch, batch_idx=None)

Method required for validating using Lightning AI Trainer.

test_step(batch, batch_idx=None)

Method required for testing using Lightning AI Trainer.

predict_step(batch, batch_idx=None)

Method required for predicting using Lightning AI Trainer.

configure_optimizers()

Method required by the Lightning AI Trainer to set up the optimizer used during training.