Classify

Missing Modalities in Multimodal healthcare data (M3Care)

class imml.classify.M3Care(input_dim: list = None, hidden_dim: int = 128, embed_size: int = 128, modalities: list = None, vocab: list = None, learning_rate: float = 0.0001, weight_decay: float = 0.0001, output_dim: int = 1, loss_fn: callable = None, keep_prob: float = 0.5, extractors: list = None)

Bases: object

Missing Modalities in Multimodal healthcare data (M3Care). [1] [2] [3]

M3Care is a multimodal classification framework that handles missing modalities by imputing latent task-relevant information using similar samples, based on a modality-adaptive similarity metric. It supports heterogeneous input types (e.g., tabular, text, vision).

This class provides training, validation, testing, and prediction logic compatible with the Lightning Trainer.
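The imputation idea described above can be illustrated with a small NumPy sketch. It is a deliberately simplified stand-in, not M3Care's actual modality-adaptive metric: it assumes one fully observed reference modality, uses plain cosine similarity, and the function name `impute_missing_embeddings` is hypothetical.

```python
import numpy as np

def impute_missing_embeddings(ref, target, observed):
    """Fill missing rows of `target` with a similarity-weighted average.

    ref      : (n, d_ref) embeddings of a modality observed for every sample
    target   : (n, d_tgt) embeddings of a modality with some rows missing
    observed : (n,) boolean mask, True where `target` is available
    """
    # Cosine similarity between samples, computed on the reference modality.
    unit = ref / np.linalg.norm(ref, axis=1, keepdims=True)
    sim = unit @ unit.T

    # Softmax over donors only: samples without the target modality get weight 0.
    logits = np.where(observed[None, :], sim, -np.inf)
    weights = np.exp(logits - logits.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)

    # Zero out missing rows so NaNs cannot leak into the weighted average.
    donors = np.where(observed[:, None], target, 0.0)
    filled = target.copy()
    filled[~observed] = weights[~observed] @ donors
    return filled

rng = np.random.default_rng(0)
ref = rng.random((5, 4))
target = rng.random((5, 3))
observed = np.array([True, False, True, True, False])
target[~observed] = np.nan  # simulate a missing modality for two samples
filled = impute_missing_embeddings(ref, target, observed)
```

Observed rows are left untouched. Each imputed row is a convex combination of the observed rows, weighted by how similar the samples are in the reference modality.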

Parameters:
  • input_dim (list of int, default=None) -- A list specifying the input dimensions for each tabular modality.

  • hidden_dim (int, default=128) -- Hidden dimension size.

  • embed_size (int, default=128) -- Size of the shared embedding space where modalities are projected.

  • modalities (list of str, default=None) -- Names of the modalities. Options are "tabular", "text" and "image".

  • vocab (list, default=None) -- A three-element list containing the path to the corpus file, the maximum number of words in the vocabulary, and freq_cutoff (words that occur fewer than freq_cutoff times are dropped). To pass your own Vocab object, use a single-element list [Vocab]. If None, ["test.de-en.en", 50000, 2] is used (if applicable). [2]

  • learning_rate (float, default=1e-4) -- Learning rate for the optimizer.

  • weight_decay (float, default=1e-4) -- Weight decay used by the optimizer.

  • output_dim (int, default=1) -- Output dimension of the classifier. Typically 1 for binary classification; otherwise, the number of classes in your response variable.

  • loss_fn (callable, default=None) -- Loss function. If None, defaults to nn.BCEWithLogitsLoss() if output_dim <= 2, else nn.CrossEntropyLoss().

  • keep_prob (float, default=0.5) -- Dropout keep probability used in MLP layers.

  • extractors (list of nn.Module, default=None) -- List of custom feature extractors for each modality. If None, defaults will be used.
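As an illustration of how the maximum vocabulary size and freq_cutoff entries of vocab interact, here is a minimal standard-library sketch. The helper `build_vocab` and the special tokens `<pad>` and `<unk>` are assumptions made for this example; this is not the library's Vocab class.

```python
from collections import Counter

def build_vocab(sentences, max_size, freq_cutoff):
    """Map words to integer ids, dropping rare words and capping the size."""
    counts = Counter(word for sent in sentences for word in sent.split())
    # Drop words that occur fewer than freq_cutoff times.
    kept = [w for w, c in counts.most_common() if c >= freq_cutoff]
    # Keep at most max_size of the most frequent remaining words.
    kept = kept[:max_size]
    # Reserve ids for padding and unknown words (an assumption of this sketch).
    word2id = {"<pad>": 0, "<unk>": 1}
    for w in kept:
        word2id[w] = len(word2id)
    return word2id

corpus = ["the logo of iMML", "the graphical abstract of iMML", "a logo"]
vocab = build_vocab(corpus, max_size=50000, freq_cutoff=2)
small_vocab = build_vocab(corpus, max_size=2, freq_cutoff=2)
```

With freq_cutoff=2, words that appear only once in the corpus ("graphical", "abstract", "a") are excluded. With max_size=2, only the two most frequent surviving words are kept alongside the special tokens.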

References

See also

M3CareDataset

Example

>>> from lightning import Trainer
>>> import numpy as np
>>> import pandas as pd
>>> from torch.utils.data import DataLoader
>>> from imml.classify import M3Care
>>> from imml.load import M3CareDataset
>>> from imml.ampute import Amputer
>>> Xs = [pd.DataFrame(np.random.default_rng(42).random((2, 10)))]
>>> Xs.append(pd.DataFrame(np.random.default_rng(42).random((2, 15))))
>>> Xs.append(pd.DataFrame(["docs/figures/graph.png", "docs/figures/logo_imml.png"]))
>>> Xs.append(pd.DataFrame(["This is the graphical abstract of iMML.", "This is the logo of iMML."]))
>>> Xs = Amputer(p=0.2, random_state=42).fit_transform(Xs) # this step is optional
>>> y = pd.Series(np.random.default_rng(42).integers(0, 2, len(Xs[0])), dtype=np.float32)
>>> train_data = M3CareDataset(Xs=Xs, y=y)
>>> train_dataloader = DataLoader(dataset=train_data, batch_size=10, shuffle=True)
>>> trainer = Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
>>> modalities = ["tabular", "tabular", "image", "text"]
>>> estimator = M3Care(modalities=modalities, input_dim=[X.shape[1] for X,mod in zip(Xs, modalities) if mod=="tabular"])
>>> trainer.fit(estimator, train_dataloader)
>>> trainer.predict(estimator, train_dataloader)

training_step(batch, batch_idx=None)

Method required for training with the Lightning Trainer.

validation_step(batch, batch_idx=None)

Method required for validating with the Lightning Trainer.

test_step(batch, batch_idx=None)

Method required for testing with the Lightning Trainer.

predict_step(batch, batch_idx=None)

Method required for predicting with the Lightning Trainer.

configure_optimizers()

Method required by the Lightning Trainer to set up the optimizer used for training.

MUtual-conSistEnt graph contrastive learning (MUSE)

class imml.classify.MUSE(input_dim: list = None, hidden_dim: int = 128, modalities: list = None, tokenizer=None, learning_rate: float = 0.0002, weight_decay: float = 0.0, output_dim: int = 1, extractors: list = None, gnn_layers: int = 2, gnn_norm: str = None, loss_fn: callable = None, bert_type: str = 'prajjwal1/bert-tiny', dropout: float = 0.25)

Bases: object

MUtual-conSistEnt graph contrastive learning (MUSE). [4] [5]

MUSE is a multimodal representation learning framework designed to handle missing modalities and partially labeled data. It uses a bipartite graph between samples and modalities to support arbitrary missingness patterns and a mutual-consistent contrastive loss to encourage the learning of label-discriminative, modality-consistent features.

This class provides training, validation, testing, and prediction logic compatible with the Lightning Trainer.
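The bipartite sample-modality graph mentioned above can be pictured with a short NumPy sketch. It only shows how a missingness mask maps to edges and is not MUSE's internal implementation; the function name `bipartite_edges` is hypothetical.

```python
import numpy as np

def bipartite_edges(observed):
    """Return a (2, n_edges) array of (sample, modality) index pairs.

    observed : (n_samples, n_modalities) boolean mask, True where the
               modality is available for that sample.
    """
    sample_idx, modality_idx = np.nonzero(observed)
    return np.stack([sample_idx, modality_idx])

# Four samples, three modalities, with an arbitrary missingness pattern.
observed = np.array([[True,  True,  False],
                     [True,  False, True],
                     [False, True,  True],
                     [True,  True,  True]])
edges = bipartite_edges(observed)

# Number of samples connected to each modality node.
modality_degree = observed.sum(axis=0)
```

An edge exists only where a modality is observed, so any missingness pattern is represented by leaving the corresponding edges out.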

Parameters:
  • input_dim (list of int, default=None) -- A list specifying the input dimensions for each tabular/series modality.

  • hidden_dim (int, default=128) -- Hidden dimension size.

  • modalities (list of str, default=None) -- Names of the modalities. Options are "tabular", "text" and "series".

  • tokenizer (str, default=None) -- Tokenizer to use for text modality. If None, defaults to "emilyalsentzer/Bio_ClinicalBERT" tokenizer.

  • learning_rate (float, default=2e-4) -- Learning rate for the optimizer.

  • weight_decay (float, default=0) -- Weight decay used by the optimizer.

  • output_dim (int, default=1) -- Output dimension of the classifier. Typically 1 for binary classification; otherwise, the number of classes in your response variable.

  • extractors (list of nn.Module, default=None) -- List of custom feature extractors for each modality. If None, defaults will be used.

  • gnn_layers (int, default=2) -- Number of GNN layers used to propagate sample-modality representations.

  • gnn_norm (str or None, default=None) -- Optional normalization strategy in GNN layers (e.g., 'batchnorm', 'layernorm').

  • loss_fn (callable, default=None) -- Loss function. If None, defaults to nn.BCEWithLogitsLoss() if output_dim <= 2, else nn.CrossEntropyLoss().

  • bert_type (str, default="prajjwal1/bert-tiny") -- HuggingFace model name or path for BERT backbone used in the text encoder.

  • dropout (float, default=0.25) -- Dropout rate applied in the encoders and classifier head.

References

See also

MUSEDataset

Example

>>> from lightning import Trainer
>>> import numpy as np
>>> import pandas as pd
>>> from torch.utils.data import DataLoader
>>> from imml.classify import MUSE
>>> from imml.load import MUSEDataset
>>> from imml.ampute import Amputer
>>> Xs = [pd.DataFrame(np.random.default_rng(42).random((2, 10)))]
>>> Xs.append(pd.DataFrame(np.random.default_rng(42).random((2, 15))))
>>> Xs.append(pd.DataFrame(["This is the graphical abstract of iMML.", "This is the logo of iMML."]))
>>> Xs = Amputer(p=0.2, random_state=42).fit_transform(Xs) # this step is optional
>>> y = pd.Series(np.random.default_rng(42).integers(0, 2, len(Xs[0])), dtype=np.float32)
>>> train_data = MUSEDataset(Xs=Xs, y=y)
>>> train_dataloader = DataLoader(dataset=train_data, batch_size=10, shuffle=True)
>>> trainer = Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
>>> modalities = ["tabular", "tabular", "text"]
>>> estimator = MUSE(modalities=modalities, input_dim=[Xs[0].shape[1], Xs[1].shape[1]])
>>> trainer.fit(estimator, train_dataloader)
>>> trainer.predict(estimator, train_dataloader)

training_step(batch, batch_idx=None)

Method required for training with the Lightning Trainer.

validation_step(batch, batch_idx=None)

Method required for validating with the Lightning Trainer.

test_step(batch, batch_idx=None)

Method required for testing with the Lightning Trainer.

predict_step(batch, batch_idx=None)

Method required for predicting with the Lightning Trainer.

configure_optimizers()

Method required by the Lightning Trainer to set up the optimizer used for training.

Retrieval-AuGmented dynamic Prompt Tuning (RAGPT)

class imml.classify.RAGPT(vilt: object = None, max_text_len: int = 40, max_image_len: int = 145, prompt_position: int = 0, prompt_length: int = 1, dropout_rate: float = 0.2, hidden_dim: int = 768, output_dim: int = 2, loss_fn: callable = None, learning_rate: float = 0.001, weight_decay: float = 0.02)

Bases: object

Retrieval-AuGmented dynamic Prompt Tuning (RAGPT). [6] [7]

RAGPT is designed for incomplete vision-language learning, where one modality may be missing at inference or training time. It combines three core modules to address this challenge: 1) Multi-Channel Retriever, which retrieves semantically similar instances from a training database, per modality; 2) Missing Modality Generator, which fills in missing modality data using context from retrieved neighbors; and 3) Context-Aware Prompter, which generates dynamic prompts based on context to improve downstream classification in multimodal transformers.

This class provides training, validation, testing, and prediction logic compatible with the Lightning Trainer.
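The retrieve-then-fill idea behind the first two modules can be sketched in NumPy as follows. This is a toy illustration under stated assumptions: random vectors stand in for real image and text features, `retrieve_top_k` is a hypothetical helper, and averaging the neighbours is a simple stand-in for the Missing Modality Generator, not RAGPT's actual components.

```python
import numpy as np

def retrieve_top_k(query, bank, k):
    """Return indices of the k rows of `bank` most cosine-similar to `query`."""
    q = query / np.linalg.norm(query)
    b = bank / np.linalg.norm(bank, axis=1, keepdims=True)
    scores = b @ q
    return np.argsort(-scores)[:k]

rng = np.random.default_rng(0)
image_bank = rng.normal(size=(6, 8))  # image features of the training database
text_bank = rng.normal(size=(6, 8))   # text features of the same six instances

# A sample whose text is missing: search the image channel with its image features.
image_query = image_bank[3] + 0.01 * rng.normal(size=8)
neighbours = retrieve_top_k(image_query, image_bank, k=2)

# Stand-in for the generator: average the retrieved neighbours' text features.
generated_text = text_bank[neighbours].mean(axis=0)
```

Because the query is a lightly perturbed copy of instance 3, that instance is retrieved first. The same function could be applied to the text channel when the image is the missing modality.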

Parameters:
  • vilt (transformers.ViltModel, default=None) -- Pretrained model used for joint vision-language encoding. If None, defaults to ViltModel.from_pretrained('dandelin/vilt-b32-mlm').

  • max_text_len (int, default=40) -- Maximum token length for text inputs (used during prompt generation). Must not exceed the max_position_embeddings of the ViLT model (default: 40 for 'dandelin/vilt-b32-mlm').

  • max_image_len (int, default=145) -- Maximum number of image patches/tokens processed by the vision encoder.

  • prompt_position (int, default=0) -- Index position at which to insert dynamic prompts in the transformer input sequence.

  • prompt_length (int, default=1) -- Number of prompt tokens to insert for dynamic prompt tuning.

  • dropout_rate (float, default=0.2) -- Dropout probability.

  • hidden_dim (int, default=768) -- Hidden dimension size.

  • output_dim (int, default=2) -- Number of classes in your response variable. Typically 2 for binary classification.

  • loss_fn (callable, default=None) -- Loss function. If None, defaults to nn.BCEWithLogitsLoss() if output_dim <= 2, else nn.CrossEntropyLoss().

  • learning_rate (float, default=1e-3) -- Learning rate for the optimizer.

  • weight_decay (float, default=2e-2) -- Weight decay used by the optimizer.

References

See also

RAGPTDataset, MCR

Example

>>> from imml.retrieve import MCR
>>> from imml.load import RAGPTDataset, RAGPTCollator
>>> from imml.classify import RAGPT
>>> from lightning import Trainer
>>> from torch.utils.data import DataLoader
>>> images = ["docs/figures/graph.png", "docs/figures/logo_imml.png",
...           "docs/figures/graph.png", "docs/figures/logo_imml.png"]
>>> texts = ["This is the graphical abstract of iMML.", "This is the logo of iMML.",
...          "This is the graphical abstract of iMML.", "This is the logo of iMML."]
>>> Xs = [images, texts]
>>> y = [0, 1, 0, 1]
>>> modalities = ["image", "text"]
>>> estimator = MCR(modalities=modalities)
>>> database = estimator.fit_transform(Xs=Xs, y=y)
>>> train_data = RAGPTDataset(database=database)
>>> train_dataloader = DataLoader(train_data, collate_fn=RAGPTCollator)
>>> trainer = Trainer(max_epochs=2, logger=False, enable_checkpointing=False)
>>> estimator = RAGPT()
>>> trainer.fit(estimator, train_dataloader)
>>> trainer.predict(estimator, train_dataloader)

training_step(batch, batch_idx=None)

Method required for training with the Lightning Trainer.

validation_step(batch, batch_idx=None)

Method required for validating with the Lightning Trainer.

test_step(batch, batch_idx=None)

Method required for testing with the Lightning Trainer.

predict_step(batch, batch_idx=None)

Method required for predicting with the Lightning Trainer.

configure_optimizers()

Method required by the Lightning Trainer to set up the optimizer used for training.