diff --git a/pina/data/__init__.py b/pina/data/__init__.py index 70e1000..7bb328b 100644 --- a/pina/data/__init__.py +++ b/pina/data/__init__.py @@ -4,4 +4,3 @@ __all__ = ["PinaDataModule", "PinaDataset"] from .data_module import PinaDataModule -from .dataset import PinaDataset diff --git a/pina/data/data_module.py b/pina/data/data_module.py index 9ed5c64..7f467b7 100644 --- a/pina/data/data_module.py +++ b/pina/data/data_module.py @@ -11,203 +11,8 @@ from torch_geometric.data import Data from torch.utils.data import DataLoader, SequentialSampler, RandomSampler from torch.utils.data.distributed import DistributedSampler from ..label_tensor import LabelTensor -from .dataset import PinaDatasetFactory, PinaTensorDataset - - -class DummyDataloader: - - def __init__(self, dataset): - """ - Prepare a dataloader object that returns the entire dataset in a single - batch. Depending on the number of GPUs, the dataset is managed - as follows: - - - **Distributed Environment** (multiple GPUs): Divides dataset across - processes using the rank and world size. Fetches only portion of - data corresponding to the current process. - - **Non-Distributed Environment** (single GPU): Fetches the entire - dataset. - - :param PinaDataset dataset: The dataset object to be processed. - - .. note:: - This dataloader is used when the batch size is ``None``. - """ - - if ( - torch.distributed.is_available() - and torch.distributed.is_initialized() - ): - rank = torch.distributed.get_rank() - world_size = torch.distributed.get_world_size() - if len(dataset) < world_size: - raise RuntimeError( - "Dimension of the dataset smaller than world size." - " Increase the size of the partition or use a single GPU" - ) - idx, i = [], rank - while i < len(dataset): - idx.append(i) - i += world_size - self.dataset = dataset.fetch_from_idx_list(idx) - else: - self.dataset = dataset.get_all_data() - - def __iter__(self): - return self - - def __len__(self): - return 1 - - def __next__(self): - return self.dataset - - -class Collator: - """ - This callable class is used to collate the data points fetched from the - dataset. The collation is performed based on the type of dataset used and - on the batching strategy. - """ - - def __init__( - self, max_conditions_lengths, automatic_batching, dataset=None - ): - """ - Initialize the object, setting the collate function based on whether - automatic batching is enabled or not. - - :param dict max_conditions_lengths: ``dict`` containing the maximum - number of data points to consider in a single batch for - each condition. - :param bool automatic_batching: Whether automatic PyTorch batching is - enabled or not. For more information, see the - :class:`~pina.data.data_module.PinaDataModule` class. - :param PinaDataset dataset: The dataset where the data is stored. 
- """ - - self.max_conditions_lengths = max_conditions_lengths - # Set the collate function based on the batching strategy - # collate_pina_dataloader is used when automatic batching is disabled - # collate_torch_dataloader is used when automatic batching is enabled - self.callable_function = ( - self._collate_torch_dataloader - if automatic_batching - else (self._collate_pina_dataloader) - ) - self.dataset = dataset - - # Set the function which performs the actual collation - if isinstance(self.dataset, PinaTensorDataset): - # If the dataset is a PinaTensorDataset, use this collate function - self._collate = self._collate_tensor_dataset - else: - # If the dataset is a PinaDataset, use this collate function - self._collate = self._collate_graph_dataset - - def _collate_pina_dataloader(self, batch): - """ - Function used to create a batch when automatic batching is disabled. - - :param list[int] batch: List of integers representing the indices of - the data points to be fetched. - :return: Dictionary containing the data points fetched from the dataset. - :rtype: dict - """ - # Call the fetch_from_idx_list method of the dataset - return self.dataset.fetch_from_idx_list(batch) - - def _collate_torch_dataloader(self, batch): - """ - Function used to collate the batch - - :param list[dict] batch: List of retrieved data. - :return: Dictionary containing the data points fetched from the dataset, - collated. - :rtype: dict - """ - - batch_dict = {} - if isinstance(batch, dict): - return batch - conditions_names = batch[0].keys() - # Condition names - for condition_name in conditions_names: - single_cond_dict = {} - condition_args = batch[0][condition_name].keys() - for arg in condition_args: - data_list = [ - batch[idx][condition_name][arg] - for idx in range( - min( - len(batch), - self.max_conditions_lengths[condition_name], - ) - ) - ] - single_cond_dict[arg] = self._collate(data_list) - - batch_dict[condition_name] = single_cond_dict - return batch_dict - - @staticmethod - def _collate_tensor_dataset(data_list): - """ - Function used to collate the data when the dataset is a - :class:`~pina.data.dataset.PinaTensorDataset`. - - :param data_list: Elements to be collated. - :type data_list: list[torch.Tensor] | list[LabelTensor] - :return: Batch of data. - :rtype: dict - - :raises RuntimeError: If the data is not a :class:`torch.Tensor` or a - :class:`~pina.label_tensor.LabelTensor`. - """ - - if isinstance(data_list[0], LabelTensor): - return LabelTensor.stack(data_list) - if isinstance(data_list[0], torch.Tensor): - return torch.stack(data_list) - raise RuntimeError("Data must be Tensors or LabelTensor ") - - def _collate_graph_dataset(self, data_list): - """ - Function used to collate data when the dataset is a - :class:`~pina.data.dataset.PinaGraphDataset`. - - :param data_list: Elememts to be collated. - :type data_list: list[Data] | list[Graph] - :return: Batch of data. - :rtype: dict - - :raises RuntimeError: If the data is not a - :class:`~torch_geometric.data.Data` or a :class:`~pina.graph.Graph`. - """ - if isinstance(data_list[0], LabelTensor): - return LabelTensor.cat(data_list) - if isinstance(data_list[0], torch.Tensor): - return torch.cat(data_list) - if isinstance(data_list[0], Data): - return self.dataset.create_batch(data_list) - raise RuntimeError( - "Data must be Tensors or LabelTensor or pyG " - "torch_geometric.data.Data" - ) - - def __call__(self, batch): - """ - Perform the collation of data fetched from the dataset. 
The behavoior - of the function is set based on the batching strategy during class - initialization. - - :param batch: List of retrieved data or sampled indices. - :type batch: list[int] | list[dict] - :return: Dictionary containing colleted data fetched from the dataset. - :rtype: dict - """ - - return self.callable_function(batch) +from .dataset import PinaDatasetFactory +from .dataloader import PinaDataLoader class PinaSampler: @@ -235,6 +40,19 @@ class PinaSampler: return sampler +def DataloaderCollector(): + + def __init__(self, dataloader_list): + """ + Initialize the object. + """ + assert isinstance(dataloader_list, list) + assert all( + isinstance(dataloader, DataLoader) for dataloader in dataloader_list + ) + self.dataloader_list = dataloader_list + + class PinaDataModule(LightningDataModule): """ This class extends :class:`~lightning.pytorch.core.LightningDataModule`, @@ -376,23 +194,23 @@ class PinaDataModule(LightningDataModule): if stage == "fit" or stage is None: self.train_dataset = PinaDatasetFactory( self.data_splits["train"], - max_conditions_lengths=self.find_max_conditions_lengths( - "train" - ), + # max_conditions_lengths=self.find_max_conditions_lengths( + # "train" + # ), automatic_batching=self.automatic_batching, ) if "val" in self.data_splits.keys(): self.val_dataset = PinaDatasetFactory( self.data_splits["val"], - max_conditions_lengths=self.find_max_conditions_lengths( - "val" - ), + # max_conditions_lengths=self.find_max_conditions_lengths( + # "val" + # ), automatic_batching=self.automatic_batching, ) elif stage == "test": self.test_dataset = PinaDatasetFactory( self.data_splits["test"], - max_conditions_lengths=self.find_max_conditions_lengths("test"), + # max_conditions_lengths=self.find_max_conditions_lengths("test"), automatic_batching=self.automatic_batching, ) else: @@ -502,32 +320,14 @@ class PinaDataModule(LightningDataModule): ), module="lightning.pytorch.trainer.connectors.data_connector", ) - # Use custom batching (good if batch size is large) - if self.batch_size is not None: - sampler = PinaSampler(dataset) - if self.automatic_batching: - collate = Collator( - self.find_max_conditions_lengths(split), - self.automatic_batching, - dataset=dataset, - ) - else: - collate = Collator( - None, self.automatic_batching, dataset=dataset - ) - return DataLoader( - dataset, - self.batch_size, - collate_fn=collate, - sampler=sampler, - num_workers=self.num_workers, - ) - dataloader = DummyDataloader(dataset) - dataloader.dataset = self._transfer_batch_to_device( - dataloader.dataset, self.trainer.strategy.root_device, 0 + return PinaDataLoader( + dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + collate_fn=None, + common_batch_size=True, ) - self.transfer_batch_to_device = self._transfer_batch_to_device_dummy - return dataloader def find_max_conditions_lengths(self, split): """ diff --git a/pina/data/dataloader.py b/pina/data/dataloader.py new file mode 100644 index 0000000..6855171 --- /dev/null +++ b/pina/data/dataloader.py @@ -0,0 +1,245 @@ +from torch.utils.data import DataLoader +from functools import partial +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data.sampler import SequentialSampler +import torch + + +class DummyDataloader: + + def __init__(self, dataset): + """ + Prepare a dataloader object that returns the entire dataset in a single + batch. 
Depending on the number of GPUs, the dataset is managed + as follows: + + - **Distributed Environment** (multiple GPUs): Divides dataset across + processes using the rank and world size. Fetches only portion of + data corresponding to the current process. + - **Non-Distributed Environment** (single GPU): Fetches the entire + dataset. + + :param PinaDataset dataset: The dataset object to be processed. + + .. note:: + This dataloader is used when the batch size is ``None``. + """ + print("Using DummyDataloader") + if ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + ): + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + if len(dataset) < world_size: + raise RuntimeError( + "Dimension of the dataset smaller than world size." + " Increase the size of the partition or use a single GPU" + ) + idx, i = [], rank + while i < len(dataset): + idx.append(i) + i += world_size + else: + idx = list(range(len(dataset))) + + self.dataset = dataset._getitem_from_list(idx) + + def __iter__(self): + return self + + def __len__(self): + return 1 + + def __next__(self): + return self.dataset + + +class PinaSampler: + """ + This class is used to create the sampler instance based on the shuffle + parameter and the environment in which the code is running. + """ + + def __new__(cls, dataset, shuffle=True): + """ + Instantiate and initialize the sampler. + + :param PinaDataset dataset: The dataset from which to sample. + :return: The sampler instance. + :rtype: :class:`torch.utils.data.Sampler` + """ + + if ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + ): + sampler = DistributedSampler(dataset, shuffle=shuffle) + else: + if shuffle: + sampler = torch.utils.data.RandomSampler(dataset) + else: + sampler = SequentialSampler(dataset) + return sampler + + +def _collect_items(batch): + """ + Helper function to collect items from a batch of graph data samples. + :param batch: List of graph data samples. + """ + to_return = {name: [] for name in batch[0].keys()} + for sample in batch: + for k, v in sample.items(): + to_return[k].append(v) + return to_return + + +def collate_fn_custom(batch, dataset): + """ + Override the default collate function to handle datasets without automatic batching. + :param batch: List of indices from the dataset. + :param dataset: The PinaDataset instance (must be provided). + """ + return dataset._getitem_from_list(batch) + + +def collate_fn_default(batch, stack_fn): + """ + Default collate function that simply returns the batch as is. + :param batch: List of data samples. + """ + print("Using default collate function") + to_return = _collect_items(batch) + return {k: stack_fn[k](v) for k, v in to_return.items()} + + +class PinaDataLoader: + """ + Custom DataLoader for PinaDataset. 
+    """
+
+    def __init__(
+        self,
+        dataset_dict,
+        batch_size,
+        shuffle=False,
+        num_workers=0,
+        collate_fn=None,
+        common_batch_size=True,
+    ):
+        self.dataset_dict = dataset_dict
+        self.batch_size = batch_size
+        self.shuffle = shuffle
+        self.num_workers = num_workers
+        self.collate_fn = collate_fn
+
+        print(batch_size)
+
+        if batch_size is None:
+            batch_size_per_dataset = {
+                split: None for split in dataset_dict.keys()
+            }
+        else:
+            if common_batch_size:
+                batch_size_per_dataset = {
+                    split: batch_size for split in dataset_dict.keys()
+                }
+            else:
+                batch_size_per_dataset = self._compute_batch_size()
+        self.dataloaders = {
+            split: self._create_dataloader(
+                dataset, batch_size_per_dataset[split]
+            )
+            for split, dataset in dataset_dict.items()
+        }
+
+    def _compute_batch_size(self):
+        """
+        Split the global batch size across datasets proportionally to their lengths.
+        """
+        elements_per_dataset = {
+            dataset_name: len(dataset)
+            for dataset_name, dataset in self.dataset_dict.items()
+        }
+        total_elements = sum(el for el in elements_per_dataset.values())
+        portion_per_dataset = {
+            name: el / total_elements
+            for name, el in elements_per_dataset.items()
+        }
+        batch_size_per_dataset = {
+            name: max(1, int(portion * self.batch_size))
+            for name, portion in portion_per_dataset.items()
+        }
+        tot_el_per_batch = sum(el for el in batch_size_per_dataset.values())
+
+        if self.batch_size > tot_el_per_batch:
+            difference = self.batch_size - tot_el_per_batch
+            while difference > 0:
+                for k, v in batch_size_per_dataset.items():
+                    if difference == 0:
+                        break
+                    if v > 1:
+                        batch_size_per_dataset[k] += 1
+                        difference -= 1
+        if self.batch_size < tot_el_per_batch:
+            difference = tot_el_per_batch - self.batch_size
+            while difference > 0:
+                for k, v in batch_size_per_dataset.items():
+                    if difference == 0:
+                        break
+                    if v > 1:
+                        batch_size_per_dataset[k] -= 1
+                        difference -= 1
+        return batch_size_per_dataset
+
+    def _create_dataloader(self, dataset, batch_size):
+        print(batch_size)
+        if batch_size is None:
+            return DummyDataloader(dataset)
+
+        if not dataset.automatic_batching:
+            collate_fn = partial(collate_fn_custom, dataset=dataset)
+        else:
+            collate_fn = partial(collate_fn_default, stack_fn=dataset.stack_fn)
+        return DataLoader(
+            dataset,
+            batch_size=batch_size,
+            num_workers=self.num_workers,
+            collate_fn=collate_fn,
+            sampler=PinaSampler(dataset, shuffle=self.shuffle),
+        )
+
+    def __len__(self):
+        return max(len(dl) for dl in self.dataloaders.values())
+
+    def __iter__(self):
+        """
+        Return an iterator that yields a dictionary of batches per step.
+
+        Iterate for as many steps as the longest dataloader (as in __len__),
+        restarting the shorter dataloaders whenever they are exhausted.
+        """
+        # 1. Create an iterator for each dataloader
+        iterators = {split: iter(dl) for split, dl in self.dataloaders.items()}
+
+        # 2. Iterate for the number of batches of the longest dataloader
+        for _ in range(len(self)):
+
+            # 3. Prepare the batch dictionary for this step
+            batch_dict = {}
+
+            # 4. Get the next batch from each iterator
+            for split, it in iterators.items():
+                try:
+                    batch = next(it)
+                except StopIteration:
+                    # 5. If an iterator is exhausted, reset it and take its first batch
+                    new_it = iter(self.dataloaders[split])
+                    iterators[split] = new_it  # Store the new iterator
+                    batch = next(new_it)
+
+                batch_dict[split] = batch
+
+            # 6. 
Restituisci il dizionario di batch + yield batch_dict diff --git a/pina/data/dataset.py b/pina/data/dataset.py index 62e3913..d829770 100644 --- a/pina/data/dataset.py +++ b/pina/data/dataset.py @@ -1,9 +1,10 @@ """Module for the PINA dataset classes.""" -from abc import abstractmethod, ABC from torch.utils.data import Dataset from torch_geometric.data import Data from ..graph import Graph, LabelBatch +from ..label_tensor import LabelTensor +import torch class PinaDatasetFactory: @@ -41,286 +42,156 @@ class PinaDatasetFactory: if len(conditions_dict) == 0: raise ValueError("No conditions provided") + dataset_dict = {} + # Check is a Graph is present in the conditions - is_graph = cls._is_graph_dataset(conditions_dict) - if is_graph: - # If a Graph is present, return a PinaGraphDataset - return PinaGraphDataset(conditions_dict, **kwargs) - # If no Graph is present, return a PinaTensorDataset - return PinaTensorDataset(conditions_dict, **kwargs) + for name, data in conditions_dict.items(): + if not isinstance(data, dict): + raise ValueError( + f"Condition '{name}' data must be a dictionary" + ) + + # is_graph = cls._is_graph_dataset(conditions_dict) + # if is_graph: + # raise NotImplementedError("PinaGraphDataset is not implemented yet.") + + dataset_dict[name] = PinaTensorDataset(data, **kwargs) + return dataset_dict @staticmethod - def _is_graph_dataset(conditions_dict): + def _is_graph_dataset(cond_data): """ - Check if a graph is present in the conditions (at least one time). - - :param conditions_dict: Dictionary containing the conditions. - :type conditions_dict: dict - :return: True if a graph is present in the conditions, False otherwise. - :rtype: bool + TODO: Docstring """ - # Iterate over the conditions dictionary - for v in conditions_dict.values(): - # Iterate over the values of the current condition - for cond in v.values(): - # Check if the current value is a list of Data objects - if isinstance(cond, (Data, Graph, list, tuple)): - return True + # Iterate over the values of the current condition + for cond in cond_data.values(): + if isinstance(cond, (Data, Graph, list, tuple)): + return True return False -class PinaDataset(Dataset, ABC): +class PinaTensorDataset(Dataset): """ - Abstract class for the PINA dataset which extends the PyTorch - :class:`~torch.utils.data.Dataset` class. It defines the common interface - for :class:`~pina.data.dataset.PinaTensorDataset` and - :class:`~pina.data.dataset.PinaGraphDataset` classes. + Dataset class for the PINA dataset with :class:`torch.Tensor` and + :class:`~pina.label_tensor.LabelTensor` data. """ - def __init__( - self, conditions_dict, max_conditions_lengths, automatic_batching - ): + def __init__(self, data_dict, automatic_batching=None): """ - Initialize the instance by storing the conditions dictionary, the - maximum number of items per conditions to consider, and the automatic - batching flag. + Initialize the instance by storing the conditions dictionary. :param dict conditions_dict: A dictionary mapping condition names to their respective data. Each key represents a condition name, and the corresponding value is a dictionary containing the associated data. - :param dict max_conditions_lengths: Maximum number of data points that - can be included in a single batch per condition. - :param bool automatic_batching: Indicates whether PyTorch automatic - batching is enabled in - :class:`~pina.data.data_module.PinaDataModule`. 
""" # Store the conditions dictionary - self.conditions_dict = conditions_dict - # Store the maximum number of conditions to consider - self.max_conditions_lengths = max_conditions_lengths - # Store length of each condition - self.conditions_length = { - k: len(v["input"]) for k, v in self.conditions_dict.items() - } - # Store the maximum length of the dataset - self.length = max(self.conditions_length.values()) - # Dynamically set the getitem function based on automatic batching - if automatic_batching: - self._getitem_func = self._getitem_int - else: - self._getitem_func = self._getitem_dummy - - def _get_max_len(self): - """ - Returns the length of the longest condition in the dataset. - - :return: Length of the longest condition in the dataset. - :rtype: int - """ - - max_len = 0 - for condition in self.conditions_dict.values(): - max_len = max(max_len, len(condition["input"])) - return max_len + self.data = data_dict + self.automatic_batching = ( + automatic_batching if automatic_batching is not None else True + ) + self.stack_fn = ( + {} + ) # LabelTensor.stack if any(isinstance(v, LabelTensor) for v in data_dict.values()) else torch.stack + for k, v in data_dict.items(): + if isinstance(v, LabelTensor): + self.stack_fn[k] = LabelTensor.stack + elif isinstance(v, torch.Tensor): + self.stack_fn[k] = torch.stack + elif isinstance(v, list) and all( + isinstance(item, (Data, Graph)) for item in v + ): + self.stack_fn[k] = LabelBatch.from_data_list + else: + raise ValueError( + f"Unsupported data type for stacking: {type(v)}" + ) def __len__(self): - return self.length + return len(next(iter(self.data.values()))) def __getitem__(self, idx): - return self._getitem_func(idx) - - def _getitem_dummy(self, idx): """ - Return the index itself. This is used when automatic batching is - disabled to postpone the data retrieval to the dataloader. - - :param int idx: Index. - :return: Index. - :rtype: int - """ - - # If automatic batching is disabled, return the data at the given index - return idx - - def _getitem_int(self, idx): - """ - Return the data at the given index in the dataset. This is used when - automatic batching is enabled. + Return the data at the given index in the dataset. :param int idx: Index. :return: A dictionary containing the data at the given index. :rtype: dict """ - # If automatic batching is enabled, return the data at the given index - return { - k: {k_data: v[k_data][idx % len(v["input"])] for k_data in v.keys()} - for k, v in self.conditions_dict.items() - } + if self.automatic_batching: + # Return the data at the given index + return { + field_name: data[idx] for field_name, data in self.data.items() + } + return idx - def get_all_data(self): - """ - Return all data in the dataset. - - :return: A dictionary containing all the data in the dataset. - :rtype: dict - """ - to_return_dict = {} - for condition, data in self.conditions_dict.items(): - len_condition = len( - data["input"] - ) # Length of the current condition - to_return_dict[condition] = self._retrive_data( - data, list(range(len_condition)) - ) # Retrieve the data from the current condition - return to_return_dict - - def fetch_from_idx_list(self, idx): + def _getitem_from_list(self, idx_list): """ Return data from the dataset given a list of indices. - :param list[int] idx: List of indices. + :param list[int] idx_list: List of indices. :return: A dictionary containing the data at the given indices. 
:rtype: dict """ - to_return_dict = {} - for condition, data in self.conditions_dict.items(): - # Get the indices for the current condition - cond_idx = idx[: self.max_conditions_lengths[condition]] - # Get the length of the current condition - condition_len = self.conditions_length[condition] - # If the length of the dataset is greater than the length of the - # current condition, repeat the indices - if self.length > condition_len: - cond_idx = [idx % condition_len for idx in cond_idx] - # Retrieve the data from the current condition - to_return_dict[condition] = self._retrive_data(data, cond_idx) - return to_return_dict - - @abstractmethod - def _retrive_data(self, data, idx_list): - """ - Abstract method to retrieve data from the dataset given a list of - indices. - """ - - -class PinaTensorDataset(PinaDataset): - """ - Dataset class for the PINA dataset with :class:`torch.Tensor` and - :class:`~pina.label_tensor.LabelTensor` data. - """ - - # Override _retrive_data method for torch.Tensor data - def _retrive_data(self, data, idx_list): - """ - Retrieve data from the dataset given a list of indices. - - :param dict data: Dictionary containing the data - (only :class:`torch.Tensor` or - :class:`~pina.label_tensor.LabelTensor`). - :param list[int] idx_list: indices to retrieve. - :return: Dictionary containing the data at the given indices. - :rtype: dict - """ - - return {k: v[idx_list] for k, v in data.items()} - - @property - def input(self): - """ - Return the input data for the dataset. - - :return: Dictionary containing the input points. - :rtype: dict - """ - return {k: v["input"] for k, v in self.conditions_dict.items()} - - def update_data(self, new_conditions_dict): - """ - Update the dataset with new data. - This method is used to update the dataset with new data. It replaces - the current data with the new data provided in the new_conditions_dict - parameter. - - :param dict new_conditions_dict: Dictionary containing the new data. - :return: None - """ - for condition, data in new_conditions_dict.items(): - if condition in self.conditions_dict: - self.conditions_dict[condition].update(data) + to_return = {} + for field_name, data in self.data.items(): + if self.stack_fn[field_name] == LabelBatch.from_data_list: + to_return[field_name] = self.stack_fn[field_name]( + [data[i] for i in idx_list] + ) else: - self.conditions_dict[condition] = data + to_return[field_name] = data[idx_list] + return to_return -class PinaGraphDataset(PinaDataset): - """ - Dataset class for the PINA dataset with :class:`~torch_geometric.data.Data` - and :class:`~pina.graph.Graph` data. - """ - - def _create_graph_batch(self, data): +class PinaGraphDataset(Dataset): + def __init__(self, data_dict, automatic_batching=None): """ - Create a LabelBatch object from a list of - :class:`~torch_geometric.data.Data` objects. + Initialize the instance by storing the conditions dictionary. - :param data: List of items to collate in a single batch. - :type data: list[Data] | list[Graph] - :return: LabelBatch object all the graph collated in a single batch - disconnected graphs. - :rtype: LabelBatch - """ - batch = LabelBatch.from_data_list(data) - return batch - - def create_batch(self, data): - """ - Create a Batch object from a list of :class:`~torch_geometric.data.Data` - objects. - - :param data: List of items to collate in a single batch. - :type data: list[Data] | list[Graph] - :return: Batch object. 
- :rtype: :class:`~torch_geometric.data.Batch` - | :class:`~pina.graph.LabelBatch` + :param dict conditions_dict: A dictionary mapping condition names to + their respective data. Each key represents a condition name, and the + corresponding value is a dictionary containing the associated data. """ - if isinstance(data[0], Data): - return self._create_graph_batch(data) - return self._create_tensor_batch(data) + # Store the conditions dictionary + self.data = data_dict + self.automatic_batching = ( + automatic_batching if automatic_batching is not None else True + ) - # Override _retrive_data method for graph handling - def _retrive_data(self, data, idx_list): + def __len__(self): + return len(next(iter(self.data.values()))) + + def __getitem__(self, idx): """ - Retrieve data from the dataset given a list of indices. + Return the data at the given index in the dataset. - :param dict data: Dictionary containing the data. - :param list[int] idx_list: List of indices to retrieve. - :return: Dictionary containing the data at the given indices. + :param int idx: Index. + :return: A dictionary containing the data at the given index. + :rtype: dict + """ + + if self.automatic_batching: + # Return the data at the given index + return { + field_name: data[idx] for field_name, data in self.data.items() + } + return idx + + def _getitem_from_list(self, idx_list): + """ + Return data from the dataset given a list of indices. + + :param list[int] idx_list: List of indices. + :return: A dictionary containing the data at the given indices. :rtype: dict """ - # Return the data from the current condition - # If the data is a list of Data objects, create a Batch object - # If the data is a list of torch.Tensor objects, create a torch.Tensor return { - k: ( - self._create_graph_batch([v[i] for i in idx_list]) - if isinstance(v, list) - else v[idx_list] - ) - for k, v in data.items() + field_name: [data[i] for i in idx_list] + for field_name, data in self.data.items() } - - @property - def input(self): - """ - Return the input data for the dataset. - - :return: Dictionary containing the input points. - :rtype: dict - """ - return {k: v["input"] for k, v in self.conditions_dict.items()}
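
For reference, the standalone sketch below mimics the two behaviours the new PinaDataLoader introduces: splitting a global batch size across per-condition datasets in proportion to their lengths (as in PinaDataLoader._compute_batch_size) and iterating all per-condition dataloaders together while restarting the shorter ones (as in PinaDataLoader.__iter__). It is an illustration under assumptions, not PINA code: split_batch_size, iterate_together, and the toy condition names are hypothetical and exist only to make the logic easy to run in isolation.

# --- illustrative sketch, not part of the patch ----------------------------


def split_batch_size(batch_size, sizes):
    """Split ``batch_size`` across named datasets proportionally to ``sizes``,
    guaranteeing at least one element per dataset."""
    if batch_size < len(sizes):
        raise ValueError("batch_size must cover at least one sample per dataset")
    total = sum(sizes.values())
    per_dataset = {
        name: max(1, int(batch_size * n / total)) for name, n in sizes.items()
    }
    # Nudge the allocation until it sums exactly to batch_size.
    diff = batch_size - sum(per_dataset.values())
    while diff != 0:
        for name in sorted(per_dataset, key=per_dataset.get, reverse=True):
            if diff == 0:
                break
            step = 1 if diff > 0 else -1
            if per_dataset[name] + step >= 1:
                per_dataset[name] += step
                diff -= step
    return per_dataset


def iterate_together(loaders):
    """Yield one dict of batches per step, for as many steps as the longest
    loader, restarting any loader that runs out early."""
    iterators = {name: iter(dl) for name, dl in loaders.items()}
    for _ in range(max(len(dl) for dl in loaders.values())):
        step = {}
        for name, it in iterators.items():
            try:
                step[name] = next(it)
            except StopIteration:
                iterators[name] = iter(loaders[name])
                step[name] = next(iterators[name])
        yield step


if __name__ == "__main__":
    # 64 samples per step, shared across three conditions of different sizes.
    print(split_batch_size(64, {"physics": 800, "data": 150, "boundary": 50}))
    # Plain lists stand in for the per-condition dataloaders.
    loaders = {"physics": [[0, 1], [2, 3], [4, 5]], "data": [[9]]}
    for batch in iterate_together(loaders):
        print(batch)

The restart-on-StopIteration pattern is what lets every training step see one batch per condition even when the conditions have very different numbers of samples.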