Source code for imml.load.ragpt_dataset

from pathlib import Path
import numpy as np
import pandas as pd
from PIL import Image

# The deep-learning stack is optional. If any of these imports fails, the
# module still imports and the classes below raise ImportError on construction.
try:
    import lightning.pytorch as pl
    from transformers import BertTokenizer
    import torch
    from ..classify._ragpt.core_tools import resize_image
    from ..classify._ragpt.vilt import ViltImageProcessor
except ImportError:
    deepmodule_installed = False
    deepmodule_error = "Module 'Deep' needs to be installed."
else:
    deepmodule_installed = True

if deepmodule_installed:
    TorchDatasetBase = torch.utils.data.Dataset
else:
    # Placeholders so that class definitions and isinstance checks below can
    # still be evaluated when the optional dependencies are missing.
    TorchDatasetBase = object
    BertTokenizer = object
    ViltImageProcessor = object


class RAGPTDataset(TorchDatasetBase):
    r"""
    This class provides a `torch.utils.data.Dataset` implementation for handling multi-modal datasets with `RAGPT`.

    If it is used with `torch.utils.data.DataLoader`, the `collate_fn` argument of the `DataLoader` constructor
    should be `RAGPTCollator`.

    Parameters
    ----------
    database : pd.DataFrame (n_samples, 14)
        A database with the retrieval-augmented prompts created by MCR. It must contain the following columns:

        - item_id: Unique identifier for each sample.
        - img_path: Path to the image file.
        - text: Textual content of the sample.
        - label: Label of the sample.
        - observed_image: Indicator of whether the image was observed.
        - observed_text: Indicator of whether the text was observed.
        - i2i_id_list: List of ids of the retrieved items for the image-to-image modality.
        - i2i_sims_list: List of similarities of the retrieved items for the image-to-image modality.
        - i2i_label_list: List of labels of the retrieved items for the image-to-image modality.
        - prompt_image_path: Path to the generated image prompt. Only if generate_cap is True.
        - t2t_id_list: List of ids of the retrieved items for the text-to-text modality.
        - t2t_sims_list: List of similarities of the retrieved items for the text-to-text modality.
        - t2t_label_list: List of labels of the retrieved items for the text-to-text modality.
        - prompt_text_path: Path to the generated text prompt. Only if generate_cap is True.
    max_text_len : int, default=128
        Maximum token length for text inputs (used during prompt generation).

    Returns
    -------
    sample : dict
        Dictionary with the following keys for one sample:

        - image: Image of the sample.
        - text: Textual content of the sample.
        - label: Label of the sample.
        - r_t_list: List with retrieved textual content for the sample.
        - r_i_list: List with retrieved image content for the sample.
        - r_l_list: List with retrieved labels for the sample.
        - observed_text: True if the text is observed, False otherwise.
        - observed_image: True if the image is observed, False otherwise.

    Example
    --------
    >>> from torch.utils.data import DataLoader
    >>> from imml.load import RAGPTDataset
    >>> from imml.load.ragpt_dataset import RAGPTCollator
    >>> from imml.retrieve import MCR
    >>> images = ["docs/figures/graph.png", "docs/figures/logo_imml.png",
    ...           "docs/figures/graph.png", "docs/figures/logo_imml.png"]
    >>> texts = ["This is the graphical abstract of iMML.", "This is the logo of iMML.",
    ...          "This is the graphical abstract of iMML.", "This is the logo of iMML."]
    >>> Xs = [images, texts]
    >>> y = [0, 1, 0, 1]
    >>> modalities = ["image", "text"]
    >>> estimator = MCR(modalities=modalities)
    >>> database = estimator.fit_transform(Xs=Xs, y=y)
    >>> train_data = RAGPTDataset(database=database)
    >>> train_dataloader = DataLoader(train_data, collate_fn=RAGPTCollator())
    """

    def __init__(self, database: pd.DataFrame, max_text_len: int = 128):
        if not deepmodule_installed:
            raise ImportError(deepmodule_error)

        if not isinstance(database, pd.DataFrame):
            raise ValueError(f"Invalid database. It must be a pandas DataFrame. A {type(database)} was passed.")
        required_columns = ['img_path', 'text', 'label', 'i2i_id_list', 't2t_id_list',
                            'prompt_image_path', 'prompt_text_path', 'i2i_label_list',
                            't2t_label_list', 'observed_image', 'observed_text']
        missing_columns = [col for col in required_columns if col not in database.columns]
        if missing_columns:
            raise ValueError(f"Invalid database. It is missing required columns: {missing_columns}")
        if not isinstance(max_text_len, int):
            raise ValueError(f"Invalid max_text_len. It must be an integer. A {type(max_text_len)} was passed.")
        if max_text_len <= 0:
            raise ValueError(f"Invalid max_text_len. It must be positive. {max_text_len} was passed.")

        super().__init__()
        self.max_text_len = max_text_len

        # Each column is materialised once as a plain Python list so that
        # __getitem__ only performs list indexing.
        self.img_path_list = database['img_path'].tolist()
        self.text_list = database['text'].tolist()
        self.label_list = database['label'].tolist()
        self.i2i_list = database['i2i_id_list'].tolist()
        self.t2t_list = database['t2t_id_list'].tolist()
        self.prompt_image_path = database['prompt_image_path'].tolist()
        self.prompt_text_path = database['prompt_text_path'].tolist()
        self.i2i_r_l_list_list = database['i2i_label_list'].tolist()
        self.t2t_r_l_list_list = database['t2t_label_list'].tolist()
        self.observed_image = database['observed_image'].tolist()
        self.observed_text = database['observed_text'].tolist()

    @staticmethod
    def _swap_modality_dir(path, source: str, target: str) -> Path:
        """Return `path` with every path component equal to `source` replaced by `target`."""
        return Path(*[(target if part == source else part) for part in Path(path).parts])

    def __getitem__(self, index):
        """
        Build the sample stored at position `index`.

        Parameters
        ----------
        index : int
            Position of the sample in the database.

        Returns
        -------
        sample : dict
            Dictionary with the keys image, text, label, r_t_list, r_i_list, r_l_list,
            observed_text and observed_image.

        Raises
        ------
        ValueError
            If neither the text nor the image is flagged as observed for this item.
        """
        text = self.text_list[index]

        img_path = self.img_path_list[index]
        if pd.notna(img_path):
            # convert() returns an independent image, so the file handle can
            # be released as soon as the pixel data has been read.
            with Image.open(img_path) as img:
                image = img.convert("RGB")
        else:
            # No image path available: use a black placeholder image.
            image = Image.new("RGBA", (256, 256), (0, 0, 0)).convert("RGB")

        label = self.label_list[index]
        observed_text = self.observed_text[index]
        observed_image = self.observed_image[index]
        prompt_image_path = self.prompt_image_path[index]
        prompt_text_path = self.prompt_text_path[index]

        r_i_list = []
        r_t_list = []
        if (observed_text == 0) and (observed_image == 1):
            # Text is missing: substitute a fixed placeholder text and use the
            # image-to-image retrieval results.
            text = "I love deep learning" * 1024
            r_l_list = self.i2i_r_l_list_list[index]
            for image_prompt in prompt_image_path:
                r_i_list.append(np.load(image_prompt).tolist())
                # The matching text prompt is read from the same path with the
                # "image" component replaced by "text".
                text_prompt = self._swap_modality_dir(image_prompt, "image", "text")
                r_t_list.append(np.load(text_prompt).tolist())
        elif (observed_text == 1) and (observed_image == 0):
            # Image is missing: use the text-to-text retrieval results.
            r_l_list = self.t2t_r_l_list_list[index]
            for text_prompt in prompt_text_path:
                r_t_list.append(np.load(text_prompt).tolist())
                # The matching image prompt is read from the same path with the
                # "text" component replaced by "image".
                image_prompt = self._swap_modality_dir(text_prompt, "text", "image")
                r_i_list.append(np.load(image_prompt).tolist())
        elif (observed_text == 1) and (observed_image == 1):
            # Both modalities observed: labels come from the image-to-image retrieval.
            r_l_list = self.i2i_r_l_list_list[index]
            for prompt_image, prompt_text in zip(prompt_image_path, prompt_text_path):
                r_i_list.append(np.load(prompt_image).tolist())
                r_t_list.append(np.load(prompt_text).tolist())
        else:
            raise ValueError(f"No available modalities for item: {index}")

        return {
            "image": image,
            "text": text,
            "label": label,
            "r_t_list": r_t_list,
            "r_i_list": r_i_list,
            "r_l_list": r_l_list,
            "observed_text": observed_text,
            "observed_image": observed_image
        }

    def __len__(self):
        """Return the number of samples in the database."""
        return len(self.label_list)
class RAGPTCollator():
    """
    Collate function that turns a list of `RAGPTDataset` samples into a batch of tensors.

    Parameters
    ----------
    tokenizer : BertTokenizer, default=None
        Tokenizer applied to the texts. If None, it is loaded with
        `BertTokenizer.from_pretrained('dandelin/vilt-b32-mlm', do_lower_case=True)`.
    image_processor : ViltImageProcessor, default=None
        Processor applied to the images. If None, it is loaded with
        `ViltImageProcessor.from_pretrained('dandelin/vilt-b32-mlm')`.
    max_text_len : int, default=128
        Length to which every tokenized text is padded or truncated.
    """

    def __init__(self, tokenizer = None, image_processor = None, max_text_len: int = 128):
        if not deepmodule_installed:
            raise ImportError(deepmodule_error)

        if tokenizer is not None and not isinstance(tokenizer, BertTokenizer):
            raise ValueError(f"Invalid tokenizer. It must be a BertTokenizer. A {type(tokenizer)} was passed.")
        if image_processor is not None and not isinstance(image_processor, ViltImageProcessor):
            raise ValueError(f"Invalid image_processor. It must be a ViltImageProcessor. A {type(image_processor)} was passed.")
        if not isinstance(max_text_len, int):
            raise ValueError(f"Invalid max_text_len. It must be an integer. A {type(max_text_len)} was passed.")
        if max_text_len <= 0:
            raise ValueError(f"Invalid max_text_len. It must be positive. {max_text_len} was passed.")

        # Pretrained defaults are only loaded after all arguments have been validated.
        if tokenizer is None:
            tokenizer = BertTokenizer.from_pretrained('dandelin/vilt-b32-mlm', do_lower_case=True)
        if image_processor is None:
            image_processor = ViltImageProcessor.from_pretrained('dandelin/vilt-b32-mlm')

        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.max_text_len = max_text_len

    def __call__(self, batch):
        """
        Collate a list of samples into a single batch.

        Parameters
        ----------
        batch : list of dict
            Samples with the keys text, image, label, r_t_list, r_i_list, r_l_list,
            observed_text and observed_image.

        Returns
        -------
        batch : dict
            Dictionary with the following keys:

            - input_ids, token_type_ids, attention_mask: int64 tensors from the tokenizer.
            - pixel_values, pixel_mask: outputs of the image processor.
            - label: float tensor with the labels.
            - r_t_list, r_i_list: float tensors with the retrieved text and image content.
            - r_l_list: long tensor with the retrieved labels.
            - observed_image, observed_text: int64 tensors with the observation indicators.
        """
        text = [item['text'] for item in batch]
        image = [item['image'] for item in batch]
        label = [item['label'] for item in batch]
        r_t_list = [item['r_t_list'] for item in batch]
        r_i_list = [item['r_i_list'] for item in batch]
        observed_text = [item['observed_text'] for item in batch]
        observed_image = [item['observed_image'] for item in batch]
        r_l_list = [item['r_l_list'] for item in batch]

        text_encoding = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_text_len,
            return_special_tokens_mask=True,
        )
        input_ids = torch.tensor(text_encoding['input_ids'], dtype=torch.int64)
        token_type_ids = torch.tensor(text_encoding['token_type_ids'], dtype=torch.int64)
        attention_mask = torch.tensor(text_encoding['attention_mask'], dtype=torch.int64)

        image = [resize_image(img) for img in image]
        image_encoding = self.image_processor(image, return_tensors="pt")
        pixel_values = image_encoding["pixel_values"]
        pixel_mask = image_encoding["pixel_mask"]

        label = torch.tensor(label, dtype=torch.float)
        r_l_list = torch.tensor(r_l_list, dtype=torch.long)
        r_t_list = torch.tensor(r_t_list, dtype=torch.float)
        r_i_list = torch.tensor(r_i_list, dtype=torch.float)

        return {
            # input_ids is already an int64 tensor; wrapping it in
            # torch.tensor() again would copy it and emit a UserWarning.
            "input_ids": input_ids,
            "pixel_values": pixel_values,
            "pixel_mask": pixel_mask,
            "token_type_ids": token_type_ids,
            "attention_mask": attention_mask,
            "label": label,
            "r_t_list": r_t_list,
            "r_i_list": r_i_list,
            "r_l_list": r_l_list,
            "observed_image": torch.tensor(observed_image, dtype=torch.int64),
            "observed_text": torch.tensor(observed_text, dtype=torch.int64)
        }