2 Commits

Author         SHA1         Message                                                  Date
FilippoOlivo   4d172a8821   fix data pipeline and add separeate_conditions option    2025-11-12 15:59:28 +01:00
FilippoOlivo   99e2f07cf7   refact dataset, dataloader and datamodule                2025-11-12 14:32:56 +01:00
4 changed files with 353 additions and 545 deletions

View File

@@ -4,4 +4,3 @@ __all__ = ["PinaDataModule", "PinaDataset"]
from .data_module import PinaDataModule
from .dataset import PinaDataset

View File

@@ -7,232 +7,9 @@ different types of Datasets defined in PINA.
import warnings
from lightning.pytorch import LightningDataModule
import torch
-from torch_geometric.data import Data
-from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
-from torch.utils.data.distributed import DistributedSampler
from ..label_tensor import LabelTensor
-from .dataset import PinaDatasetFactory, PinaTensorDataset
+from .dataset import PinaDatasetFactory
+from .dataloader import PinaDataLoader
class DummyDataloader:
def __init__(self, dataset):
"""
Prepare a dataloader object that returns the entire dataset in a single
batch. Depending on the number of GPUs, the dataset is managed
as follows:
- **Distributed Environment** (multiple GPUs): Divides dataset across
processes using the rank and world size. Fetches only portion of
data corresponding to the current process.
- **Non-Distributed Environment** (single GPU): Fetches the entire
dataset.
:param PinaDataset dataset: The dataset object to be processed.
.. note::
This dataloader is used when the batch size is ``None``.
"""
if (
torch.distributed.is_available()
and torch.distributed.is_initialized()
):
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
if len(dataset) < world_size:
raise RuntimeError(
"Dimension of the dataset smaller than world size."
" Increase the size of the partition or use a single GPU"
)
idx, i = [], rank
while i < len(dataset):
idx.append(i)
i += world_size
self.dataset = dataset.fetch_from_idx_list(idx)
else:
self.dataset = dataset.get_all_data()
def __iter__(self):
return self
def __len__(self):
return 1
def __next__(self):
return self.dataset
class Collator:
"""
This callable class is used to collate the data points fetched from the
dataset. The collation is performed based on the type of dataset used and
on the batching strategy.
"""
def __init__(
self, max_conditions_lengths, automatic_batching, dataset=None
):
"""
Initialize the object, setting the collate function based on whether
automatic batching is enabled or not.
:param dict max_conditions_lengths: ``dict`` containing the maximum
number of data points to consider in a single batch for
each condition.
:param bool automatic_batching: Whether automatic PyTorch batching is
enabled or not. For more information, see the
:class:`~pina.data.data_module.PinaDataModule` class.
:param PinaDataset dataset: The dataset where the data is stored.
"""
self.max_conditions_lengths = max_conditions_lengths
# Set the collate function based on the batching strategy
# collate_pina_dataloader is used when automatic batching is disabled
# collate_torch_dataloader is used when automatic batching is enabled
self.callable_function = (
self._collate_torch_dataloader
if automatic_batching
else (self._collate_pina_dataloader)
)
self.dataset = dataset
# Set the function which performs the actual collation
if isinstance(self.dataset, PinaTensorDataset):
# If the dataset is a PinaTensorDataset, use this collate function
self._collate = self._collate_tensor_dataset
else:
# If the dataset is a PinaDataset, use this collate function
self._collate = self._collate_graph_dataset
def _collate_pina_dataloader(self, batch):
"""
Function used to create a batch when automatic batching is disabled.
:param list[int] batch: List of integers representing the indices of
the data points to be fetched.
:return: Dictionary containing the data points fetched from the dataset.
:rtype: dict
"""
# Call the fetch_from_idx_list method of the dataset
return self.dataset.fetch_from_idx_list(batch)
def _collate_torch_dataloader(self, batch):
"""
Function used to collate the batch
:param list[dict] batch: List of retrieved data.
:return: Dictionary containing the data points fetched from the dataset,
collated.
:rtype: dict
"""
batch_dict = {}
if isinstance(batch, dict):
return batch
conditions_names = batch[0].keys()
# Condition names
for condition_name in conditions_names:
single_cond_dict = {}
condition_args = batch[0][condition_name].keys()
for arg in condition_args:
data_list = [
batch[idx][condition_name][arg]
for idx in range(
min(
len(batch),
self.max_conditions_lengths[condition_name],
)
)
]
single_cond_dict[arg] = self._collate(data_list)
batch_dict[condition_name] = single_cond_dict
return batch_dict
@staticmethod
def _collate_tensor_dataset(data_list):
"""
Function used to collate the data when the dataset is a
:class:`~pina.data.dataset.PinaTensorDataset`.
:param data_list: Elements to be collated.
:type data_list: list[torch.Tensor] | list[LabelTensor]
:return: Batch of data.
:rtype: dict
:raises RuntimeError: If the data is not a :class:`torch.Tensor` or a
:class:`~pina.label_tensor.LabelTensor`.
"""
if isinstance(data_list[0], LabelTensor):
return LabelTensor.stack(data_list)
if isinstance(data_list[0], torch.Tensor):
return torch.stack(data_list)
raise RuntimeError("Data must be Tensors or LabelTensor ")
def _collate_graph_dataset(self, data_list):
"""
Function used to collate data when the dataset is a
:class:`~pina.data.dataset.PinaGraphDataset`.
:param data_list: Elememts to be collated.
:type data_list: list[Data] | list[Graph]
:return: Batch of data.
:rtype: dict
:raises RuntimeError: If the data is not a
:class:`~torch_geometric.data.Data` or a :class:`~pina.graph.Graph`.
"""
if isinstance(data_list[0], LabelTensor):
return LabelTensor.cat(data_list)
if isinstance(data_list[0], torch.Tensor):
return torch.cat(data_list)
if isinstance(data_list[0], Data):
return self.dataset.create_batch(data_list)
raise RuntimeError(
"Data must be Tensors or LabelTensor or pyG "
"torch_geometric.data.Data"
)
def __call__(self, batch):
"""
Perform the collation of data fetched from the dataset. The behavoior
of the function is set based on the batching strategy during class
initialization.
:param batch: List of retrieved data or sampled indices.
:type batch: list[int] | list[dict]
:return: Dictionary containing colleted data fetched from the dataset.
:rtype: dict
"""
return self.callable_function(batch)
class PinaSampler:
"""
This class is used to create the sampler instance based on the shuffle
parameter and the environment in which the code is running.
"""
def __new__(cls, dataset):
"""
Instantiate and initialize the sampler.
:param PinaDataset dataset: The dataset from which to sample.
:return: The sampler instance.
:rtype: :class:`torch.utils.data.Sampler`
"""
if (
torch.distributed.is_available()
and torch.distributed.is_initialized()
):
sampler = DistributedSampler(dataset)
else:
sampler = SequentialSampler(dataset)
return sampler
class PinaDataModule(LightningDataModule):
@@ -250,7 +27,8 @@ class PinaDataModule(LightningDataModule):
        val_size=0.1,
        batch_size=None,
        shuffle=True,
-        repeat=False,
+        common_batch_size=True,
+        separate_conditions=False,
        automatic_batching=None,
        num_workers=0,
        pin_memory=False,
@@ -271,11 +49,12 @@ class PinaDataModule(LightningDataModule):
            Default is ``None``.
        :param bool shuffle: Whether to shuffle the dataset before splitting.
            Default ``True``.
-        :param bool repeat: If ``True``, in case of batch size larger than the
-            number of elements in a specific condition, the elements are
-            repeated until the batch size is reached. If ``False``, the number
-            of elements in the batch is the minimum between the batch size and
-            the number of elements in the condition. Default is ``False``.
+        :param bool common_batch_size: If ``True``, the same batch size is used
+            for all conditions. If ``False``, each condition can have its own
+            batch size, proportional to the size of the dataset in that
+            condition. Default is ``True``.
+        :param bool separate_conditions: If ``True``, dataloaders for each
+            condition are iterated separately. Default is ``False``.
        :param automatic_batching: If ``True``, automatic PyTorch batching
            is performed, which consists of extracting one element at a time
            from the dataset and collating them into a batch. This is useful
@@ -305,7 +84,8 @@ class PinaDataModule(LightningDataModule):
        # Store fixed attributes
        self.batch_size = batch_size
        self.shuffle = shuffle
-        self.repeat = repeat
+        self.common_batch_size = common_batch_size
+        self.separate_conditions = separate_conditions
        self.automatic_batching = automatic_batching
        # If batch size is None, num_workers has no effect
@@ -376,23 +156,16 @@ class PinaDataModule(LightningDataModule):
if stage == "fit" or stage is None: if stage == "fit" or stage is None:
self.train_dataset = PinaDatasetFactory( self.train_dataset = PinaDatasetFactory(
self.data_splits["train"], self.data_splits["train"],
max_conditions_lengths=self.find_max_conditions_lengths(
"train"
),
automatic_batching=self.automatic_batching, automatic_batching=self.automatic_batching,
) )
if "val" in self.data_splits.keys(): if "val" in self.data_splits.keys():
self.val_dataset = PinaDatasetFactory( self.val_dataset = PinaDatasetFactory(
self.data_splits["val"], self.data_splits["val"],
max_conditions_lengths=self.find_max_conditions_lengths(
"val"
),
automatic_batching=self.automatic_batching, automatic_batching=self.automatic_batching,
) )
elif stage == "test": elif stage == "test":
self.test_dataset = PinaDatasetFactory( self.test_dataset = PinaDatasetFactory(
self.data_splits["test"], self.data_splits["test"],
max_conditions_lengths=self.find_max_conditions_lengths("test"),
automatic_batching=self.automatic_batching, automatic_batching=self.automatic_batching,
) )
else: else:
@@ -502,53 +275,15 @@ class PinaDataModule(LightningDataModule):
                ),
                module="lightning.pytorch.trainer.connectors.data_connector",
            )
-        # Use custom batching (good if batch size is large)
-        if self.batch_size is not None:
-            sampler = PinaSampler(dataset)
-            if self.automatic_batching:
-                collate = Collator(
-                    self.find_max_conditions_lengths(split),
-                    self.automatic_batching,
-                    dataset=dataset,
-                )
-            else:
-                collate = Collator(
-                    None, self.automatic_batching, dataset=dataset
-                )
-            return DataLoader(
-                dataset,
-                self.batch_size,
-                collate_fn=collate,
-                sampler=sampler,
-                num_workers=self.num_workers,
-            )
-        dataloader = DummyDataloader(dataset)
-        dataloader.dataset = self._transfer_batch_to_device(
-            dataloader.dataset, self.trainer.strategy.root_device, 0
-        )
-        self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
-        return dataloader
+        return PinaDataLoader(
+            dataset,
+            batch_size=self.batch_size,
+            shuffle=self.shuffle,
+            num_workers=self.num_workers,
+            collate_fn=None,
+            common_batch_size=self.common_batch_size,
+            separate_conditions=self.separate_conditions,
+        )

-    def find_max_conditions_lengths(self, split):
-        """
-        Define the maximum length for each conditions.
-
-        :param dict split: The split of the dataset.
-        :return: The maximum length per condition.
-        :rtype: dict
-        """
-        max_conditions_lengths = {}
-        for k, v in self.data_splits[split].items():
-            if self.batch_size is None:
-                max_conditions_lengths[k] = len(v["input"])
-            elif self.repeat:
-                max_conditions_lengths[k] = self.batch_size
-            else:
-                max_conditions_lengths[k] = min(
-                    len(v["input"]), self.batch_size
-                )
-        return max_conditions_lengths

    def val_dataloader(self):
        """

pina/data/dataloader.py Normal file (+242 lines)
View File
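Before the new dataloader file below: the proportional per-condition batch sizes enabled by ``common_batch_size=False`` (described in the ``PinaDataModule`` docstring above) can be illustrated with a standalone sketch. The helper name ``split_batch_size`` and the condition sizes are invented for the example; only the flooring and remainder redistribution mirror the ``_compute_batch_size`` method defined further down, and the sketch assumes the requested batch size is at least the number of conditions.

    # Standalone sketch of the proportional split used when common_batch_size=False.
    def split_batch_size(batch_size, elements_per_condition):
        total = sum(elements_per_condition.values())
        # each condition gets a share proportional to its size, floored at 1
        sizes = {
            name: max(1, int(n / total * batch_size))
            for name, n in elements_per_condition.items()
        }
        # nudge the total back to the requested batch size, one element at a time
        diff = batch_size - sum(sizes.values())
        while diff != 0:
            for name in sizes:
                if diff == 0:
                    break
                if diff > 0:
                    sizes[name] += 1
                    diff -= 1
                elif sizes[name] > 1:
                    sizes[name] -= 1
                    diff += 1
        return sizes

    print(split_batch_size(32, {"data": 1000, "physics": 250}))
    # {'data': 26, 'physics': 6}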

@@ -0,0 +1,242 @@
from torch.utils.data import DataLoader
from functools import partial
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import SequentialSampler
import torch
class DummyDataloader:
def __init__(self, dataset):
"""
Prepare a dataloader object that returns the entire dataset in a single
batch. Depending on the number of GPUs, the dataset is managed
as follows:
- **Distributed Environment** (multiple GPUs): Divides dataset across
processes using the rank and world size. Fetches only portion of
data corresponding to the current process.
- **Non-Distributed Environment** (single GPU): Fetches the entire
dataset.
:param PinaDataset dataset: The dataset object to be processed.
.. note::
This dataloader is used when the batch size is ``None``.
"""
print("Using DummyDataloader")
if (
torch.distributed.is_available()
and torch.distributed.is_initialized()
):
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
if len(dataset) < world_size:
raise RuntimeError(
"Dimension of the dataset smaller than world size."
" Increase the size of the partition or use a single GPU"
)
idx, i = [], rank
while i < len(dataset):
idx.append(i)
i += world_size
else:
idx = list(range(len(dataset)))
self.dataset = dataset._getitem_from_list(idx)
def __iter__(self):
return self
def __len__(self):
return 1
def __next__(self):
return self.dataset
class PinaSampler:
"""
This class is used to create the sampler instance based on the shuffle
parameter and the environment in which the code is running.
"""
def __new__(cls, dataset, shuffle=True):
"""
Instantiate and initialize the sampler.
:param PinaDataset dataset: The dataset from which to sample.
:return: The sampler instance.
:rtype: :class:`torch.utils.data.Sampler`
"""
if (
torch.distributed.is_available()
and torch.distributed.is_initialized()
):
sampler = DistributedSampler(dataset, shuffle=shuffle)
else:
if shuffle:
sampler = torch.utils.data.RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
return sampler
def _collect_items(batch):
"""
Helper function to collect items from a batch of graph data samples.
:param batch: List of graph data samples.
"""
to_return = {name: [] for name in batch[0].keys()}
for sample in batch:
for k, v in sample.items():
to_return[k].append(v)
return to_return
def collate_fn_custom(batch, dataset):
"""
Override the default collate function to handle datasets without automatic batching.
:param batch: List of indices from the dataset.
:param dataset: The PinaDataset instance (must be provided).
"""
return dataset._getitem_from_list(batch)
def collate_fn_default(batch, stack_fn):
"""
Default collate function that simply returns the batch as is.
:param batch: List of data samples.
"""
print("Using default collate function")
to_return = _collect_items(batch)
return {k: stack_fn[k](v) for k, v in to_return.items()}
class PinaDataLoader:
"""
Custom DataLoader for PinaDataset.
"""
def __init__(
self,
dataset_dict,
batch_size,
shuffle=False,
num_workers=0,
collate_fn=None,
common_batch_size=True,
separate_conditions=False,
):
self.dataset_dict = dataset_dict
self.batch_size = batch_size
self.shuffle = shuffle
self.num_workers = num_workers
self.collate_fn = collate_fn
self.separate_conditions = separate_conditions
if batch_size is None:
batch_size_per_dataset = {
split: None for split in dataset_dict.keys()
}
else:
if common_batch_size:
batch_size_per_dataset = {
split: batch_size for split in dataset_dict.keys()
}
else:
batch_size_per_dataset = self._compute_batch_size()
self.dataloaders = {
split: self._create_dataloader(
dataset, batch_size_per_dataset[split]
)
for split, dataset in dataset_dict.items()
}
def _compute_batch_size(self):
"""
Compute an appropriate batch size for the given dataset.
"""
elements_per_dataset = {
dataset_name: len(dataset)
for dataset_name, dataset in self.dataset_dict.items()
}
total_elements = sum(el for el in elements_per_dataset.values())
portion_per_dataset = {
name: el / total_elements
for name, el in elements_per_dataset.items()
}
batch_size_per_dataset = {
name: max(1, int(portion * self.batch_size))
for name, portion in portion_per_dataset.items()
}
tot_el_per_batch = sum(el for el in batch_size_per_dataset.values())
if self.batch_size > tot_el_per_batch:
difference = self.batch_size - tot_el_per_batch
while difference > 0:
for k, v in batch_size_per_dataset.items():
if difference == 0:
break
if v > 1:
batch_size_per_dataset[k] += 1
difference -= 1
if self.batch_size < tot_el_per_batch:
difference = tot_el_per_batch - self.batch_size
while difference > 0:
for k, v in batch_size_per_dataset.items():
if difference == 0:
break
if v > 1:
batch_size_per_dataset[k] -= 1
difference -= 1
return batch_size_per_dataset
def _create_dataloader(self, dataset, batch_size):
print(batch_size)
if batch_size is None:
return DummyDataloader(dataset)
if not dataset.automatic_batching:
collate_fn = partial(collate_fn_custom, dataset=dataset)
else:
collate_fn = partial(collate_fn_default, stack_fn=dataset.stack_fn)
return DataLoader(
dataset,
batch_size=batch_size,
num_workers=self.num_workers,
collate_fn=collate_fn,
sampler=PinaSampler(dataset, shuffle=self.shuffle),
)
def __len__(self):
if self.separate_conditions:
return sum(len(dl) for dl in self.dataloaders.values())
return max(len(dl) for dl in self.dataloaders.values())
def __iter__(self):
"""
Return an iterator that yields dictionaries of batches.
It iterates for a number of steps equal to the longest dataloader (as per
__len__) and restarts the shorter dataloaders when they are exhausted.
"""
if self.separate_conditions:
for split, dl in self.dataloaders.items():
for batch in dl:
yield {split: batch}
return
iterators = {split: iter(dl) for split, dl in self.dataloaders.items()}
for _ in range(len(self)):
batch_dict = {}
for split, it in iterators.items():
try:
batch = next(it)
except StopIteration:
new_it = iter(self.dataloaders[split])
iterators[split] = new_it
batch = next(new_it)
batch_dict[split] = batch
yield batch_dict
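How ``separate_conditions`` changes iteration, and how shorter per-condition dataloaders are restarted when it is ``False`` (see ``__len__`` and ``__iter__`` above), can be mimicked with plain lists. The sketch below is illustrative only: the condition names and sizes are invented, and the modulo-based restart stands in for the iterator re-creation performed in ``__iter__``.

    # Toy sketch: joint vs. separate iteration over per-condition batches.
    conditions = {"data": list(range(6)), "physics": list(range(3))}
    batch_size = 2

    def batches(seq, size):
        return [seq[i:i + size] for i in range(0, len(seq), size)]

    loaders = {k: batches(v, batch_size) for k, v in conditions.items()}

    # separate_conditions=False: every step yields a dict with all conditions,
    # shorter loaders restarting when exhausted (len == longest loader)
    steps = max(len(b) for b in loaders.values())
    joint = [{k: b[i % len(b)] for k, b in loaders.items()} for i in range(steps)]
    print(joint)
    # [{'data': [0, 1], 'physics': [0, 1]},
    #  {'data': [2, 3], 'physics': [2]},
    #  {'data': [4, 5], 'physics': [0, 1]}]

    # separate_conditions=True: each batch holds a single condition,
    # one condition after the other (len == sum of loader lengths)
    split = [{k: batch} for k, b in loaders.items() for batch in b]
    print(split)
    # [{'data': [0, 1]}, {'data': [2, 3]}, {'data': [4, 5]},
    #  {'physics': [0, 1]}, {'physics': [2]}]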

View File
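Before the reworked dataset module below: a small standalone sketch of what the ``automatic_batching`` switch changes for the new ``PinaDataset``. Field names, shapes and indices are invented, and plain ``torch.stack`` / list slicing stand in for the ``LabelTensor`` and graph cases handled by ``stack_fn``; both paths assemble the same batch, only the access pattern differs.

    # Standalone sketch: per-sample fetch + stacking (automatic batching)
    # versus one-shot slicing with the full index list (manual batching).
    import torch

    data = {"input": torch.rand(10, 3), "target": torch.rand(10, 1)}
    idx_list = [0, 2, 5]

    # automatic batching: __getitem__ returns one sample dict at a time,
    # then the collate function stacks each field
    samples = [{k: v[i] for k, v in data.items()} for i in idx_list]
    stacked = {k: torch.stack([s[k] for s in samples]) for k in data}

    # no automatic batching: the sampler indices are handed over as a list
    # and every field is sliced in one call (as in _getitem_from_list)
    sliced = {k: v[idx_list] for k, v in data.items()}

    print({k: tuple(v.shape) for k, v in stacked.items()})
    # {'input': (3, 3), 'target': (3, 1)}
    print(all(torch.equal(stacked[k], sliced[k]) for k in data))
    # True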

@@ -1,326 +1,158 @@
"""Module for the PINA dataset classes."""
-from abc import abstractmethod, ABC
+import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data
from ..graph import Graph, LabelBatch
+from ..label_tensor import LabelTensor


class PinaDatasetFactory:
    """
-    Factory class for the PINA dataset.
-
-    Depending on the data type inside the conditions, it instanciate an object
-    belonging to the appropriate subclass of
-    :class:`~pina.data.dataset.PinaDataset`. The possible subclasses are:
-
-    - :class:`~pina.data.dataset.PinaTensorDataset`, for handling \
-      :class:`torch.Tensor` and :class:`~pina.label_tensor.LabelTensor` data.
-    - :class:`~pina.data.dataset.PinaGraphDataset`, for handling \
-      :class:`~pina.graph.Graph` and :class:`~torch_geometric.data.Data` data.
+    TODO: Update docstring
    """

    def __new__(cls, conditions_dict, **kwargs):
        """
-        Instantiate the appropriate subclass of
-        :class:`~pina.data.dataset.PinaDataset`.
-
-        If a graph is present in the conditions, returns a
-        :class:`~pina.data.dataset.PinaGraphDataset`, otherwise returns a
-        :class:`~pina.data.dataset.PinaTensorDataset`.
-
-        :param dict conditions_dict: Dictionary containing all the conditions
-            to be included in the dataset instance.
-        :return: A subclass of :class:`~pina.data.dataset.PinaDataset`.
-        :rtype: PinaTensorDataset | PinaGraphDataset
-        :raises ValueError: If an empty dictionary is provided.
+        TODO: Update docstring
        """
        # Check if conditions_dict is empty
        if len(conditions_dict) == 0:
            raise ValueError("No conditions provided")
+        dataset_dict = {}
        # Check is a Graph is present in the conditions
-        is_graph = cls._is_graph_dataset(conditions_dict)
-        if is_graph:
-            # If a Graph is present, return a PinaGraphDataset
-            return PinaGraphDataset(conditions_dict, **kwargs)
-        # If no Graph is present, return a PinaTensorDataset
-        return PinaTensorDataset(conditions_dict, **kwargs)
-
-    @staticmethod
-    def _is_graph_dataset(conditions_dict):
-        """
-        Check if a graph is present in the conditions (at least one time).
-
-        :param conditions_dict: Dictionary containing the conditions.
-        :type conditions_dict: dict
-        :return: True if a graph is present in the conditions, False otherwise.
-        :rtype: bool
-        """
-        # Iterate over the conditions dictionary
-        for v in conditions_dict.values():
-            # Iterate over the values of the current condition
-            for cond in v.values():
-                # Check if the current value is a list of Data objects
-                if isinstance(cond, (Data, Graph, list, tuple)):
-                    return True
-        return False
+        for name, data in conditions_dict.items():
+            if not isinstance(data, dict):
+                raise ValueError(
+                    f"Condition '{name}' data must be a dictionary"
+                )
+            dataset_dict[name] = PinaDataset(data, **kwargs)
+        return dataset_dict


-class PinaDataset(Dataset, ABC):
+class PinaDataset(Dataset):
    """
-    Abstract class for the PINA dataset which extends the PyTorch
-    :class:`~torch.utils.data.Dataset` class. It defines the common interface
-    for :class:`~pina.data.dataset.PinaTensorDataset` and
-    :class:`~pina.data.dataset.PinaGraphDataset` classes.
+    Dataset class for the PINA dataset with :class:`torch.Tensor` and
+    :class:`~pina.label_tensor.LabelTensor` data.
    """

-    def __init__(
-        self, conditions_dict, max_conditions_lengths, automatic_batching
-    ):
+    def __init__(self, data_dict, automatic_batching=None):
        """
-        Initialize the instance by storing the conditions dictionary, the
-        maximum number of items per conditions to consider, and the automatic
-        batching flag.
+        Initialize the instance by storing the conditions dictionary.

        :param dict conditions_dict: A dictionary mapping condition names to
            their respective data. Each key represents a condition name, and the
            corresponding value is a dictionary containing the associated data.
-        :param dict max_conditions_lengths: Maximum number of data points that
-            can be included in a single batch per condition.
-        :param bool automatic_batching: Indicates whether PyTorch automatic
-            batching is enabled in
-            :class:`~pina.data.data_module.PinaDataModule`.
        """
        # Store the conditions dictionary
-        self.conditions_dict = conditions_dict
-        # Store the maximum number of conditions to consider
-        self.max_conditions_lengths = max_conditions_lengths
-        # Store length of each condition
-        self.conditions_length = {
-            k: len(v["input"]) for k, v in self.conditions_dict.items()
-        }
-        # Store the maximum length of the dataset
-        self.length = max(self.conditions_length.values())
-        # Dynamically set the getitem function based on automatic batching
-        if automatic_batching:
-            self._getitem_func = self._getitem_int
-        else:
-            self._getitem_func = self._getitem_dummy
-
-    def _get_max_len(self):
-        """
-        Returns the length of the longest condition in the dataset.
-
-        :return: Length of the longest condition in the dataset.
-        :rtype: int
-        """
-        max_len = 0
-        for condition in self.conditions_dict.values():
-            max_len = max(max_len, len(condition["input"]))
-        return max_len
+        self.data = data_dict
+        self.automatic_batching = (
+            automatic_batching if automatic_batching is not None else True
+        )
+        self.stack_fn = {}
+        # Determine stacking functions for each data type (used in collate_fn)
+        for k, v in data_dict.items():
+            if isinstance(v, LabelTensor):
+                self.stack_fn[k] = LabelTensor.stack
+            elif isinstance(v, torch.Tensor):
+                self.stack_fn[k] = torch.stack
+            elif isinstance(v, list) and all(
+                isinstance(item, (Data, Graph)) for item in v
+            ):
+                self.stack_fn[k] = LabelBatch.from_data_list
+            else:
+                raise ValueError(
+                    f"Unsupported data type for stacking: {type(v)}"
+                )

    def __len__(self):
-        return self.length
+        return len(next(iter(self.data.values())))

    def __getitem__(self, idx):
-        return self._getitem_func(idx)
-
-    def _getitem_dummy(self, idx):
        """
-        Return the index itself. This is used when automatic batching is
-        disabled to postpone the data retrieval to the dataloader.
-
-        :param int idx: Index.
-        :return: Index.
-        :rtype: int
-        """
-        # If automatic batching is disabled, return the data at the given index
-        return idx
-
-    def _getitem_int(self, idx):
-        """
-        Return the data at the given index in the dataset. This is used when
-        automatic batching is enabled.
+        Return the data at the given index in the dataset.

        :param int idx: Index.
        :return: A dictionary containing the data at the given index.
        :rtype: dict
        """
-        # If automatic batching is enabled, return the data at the given index
-        return {
-            k: {k_data: v[k_data][idx % len(v["input"])] for k_data in v.keys()}
-            for k, v in self.conditions_dict.items()
-        }
-
-    def get_all_data(self):
-        """
-        Return all data in the dataset.
-
-        :return: A dictionary containing all the data in the dataset.
-        :rtype: dict
-        """
-        to_return_dict = {}
-        for condition, data in self.conditions_dict.items():
-            len_condition = len(
-                data["input"]
-            )  # Length of the current condition
-            to_return_dict[condition] = self._retrive_data(
-                data, list(range(len_condition))
-            )  # Retrieve the data from the current condition
-        return to_return_dict
-
-    def fetch_from_idx_list(self, idx):
+        if self.automatic_batching:
+            # Return the data at the given index
+            return {
+                field_name: data[idx] for field_name, data in self.data.items()
+            }
+        return idx
+
+    def _getitem_from_list(self, idx_list):
        """
        Return data from the dataset given a list of indices.

-        :param list[int] idx: List of indices.
+        :param list[int] idx_list: List of indices.
        :return: A dictionary containing the data at the given indices.
        :rtype: dict
        """
-        to_return_dict = {}
-        for condition, data in self.conditions_dict.items():
-            # Get the indices for the current condition
-            cond_idx = idx[: self.max_conditions_lengths[condition]]
-            # Get the length of the current condition
-            condition_len = self.conditions_length[condition]
-            # If the length of the dataset is greater than the length of the
-            # current condition, repeat the indices
-            if self.length > condition_len:
-                cond_idx = [idx % condition_len for idx in cond_idx]
-            # Retrieve the data from the current condition
-            to_return_dict[condition] = self._retrive_data(data, cond_idx)
-        return to_return_dict
-
-    @abstractmethod
-    def _retrive_data(self, data, idx_list):
-        """
-        Abstract method to retrieve data from the dataset given a list of
-        indices.
-        """
-
-
-class PinaTensorDataset(PinaDataset):
-    """
-    Dataset class for the PINA dataset with :class:`torch.Tensor` and
-    :class:`~pina.label_tensor.LabelTensor` data.
-    """
-
-    # Override _retrive_data method for torch.Tensor data
-    def _retrive_data(self, data, idx_list):
-        """
-        Retrieve data from the dataset given a list of indices.
-
-        :param dict data: Dictionary containing the data
-            (only :class:`torch.Tensor` or
-            :class:`~pina.label_tensor.LabelTensor`).
-        :param list[int] idx_list: indices to retrieve.
-        :return: Dictionary containing the data at the given indices.
-        :rtype: dict
-        """
-        return {k: v[idx_list] for k, v in data.items()}
-
-    @property
-    def input(self):
-        """
-        Return the input data for the dataset.
-
-        :return: Dictionary containing the input points.
-        :rtype: dict
-        """
-        return {k: v["input"] for k, v in self.conditions_dict.items()}
-
-    def update_data(self, new_conditions_dict):
-        """
-        Update the dataset with new data.
-
-        This method is used to update the dataset with new data. It replaces
-        the current data with the new data provided in the new_conditions_dict
-        parameter.
-
-        :param dict new_conditions_dict: Dictionary containing the new data.
-        :return: None
-        """
-        for condition, data in new_conditions_dict.items():
-            if condition in self.conditions_dict:
-                self.conditions_dict[condition].update(data)
-            else:
-                self.conditions_dict[condition] = data
+        to_return = {}
+        for field_name, data in self.data.items():
+            if self.stack_fn[field_name] == LabelBatch.from_data_list:
+                to_return[field_name] = self.stack_fn[field_name](
+                    [data[i] for i in idx_list]
+                )
+            else:
+                to_return[field_name] = data[idx_list]
+        return to_return


-class PinaGraphDataset(PinaDataset):
-    """
-    Dataset class for the PINA dataset with :class:`~torch_geometric.data.Data`
-    and :class:`~pina.graph.Graph` data.
-    """
-
-    def _create_graph_batch(self, data):
-        """
-        Create a LabelBatch object from a list of
-        :class:`~torch_geometric.data.Data` objects.
-
-        :param data: List of items to collate in a single batch.
-        :type data: list[Data] | list[Graph]
-        :return: LabelBatch object all the graph collated in a single batch
-            disconnected graphs.
-        :rtype: LabelBatch
-        """
-        batch = LabelBatch.from_data_list(data)
-        return batch
-
-    def create_batch(self, data):
-        """
-        Create a Batch object from a list of :class:`~torch_geometric.data.Data`
-        objects.
-
-        :param data: List of items to collate in a single batch.
-        :type data: list[Data] | list[Graph]
-        :return: Batch object.
-        :rtype: :class:`~torch_geometric.data.Batch`
-            | :class:`~pina.graph.LabelBatch`
-        """
-        if isinstance(data[0], Data):
-            return self._create_graph_batch(data)
-        return self._create_tensor_batch(data)
-
-    # Override _retrive_data method for graph handling
-    def _retrive_data(self, data, idx_list):
-        """
-        Retrieve data from the dataset given a list of indices.
-
-        :param dict data: Dictionary containing the data.
-        :param list[int] idx_list: List of indices to retrieve.
-        :return: Dictionary containing the data at the given indices.
-        :rtype: dict
-        """
-        # Return the data from the current condition
-        # If the data is a list of Data objects, create a Batch object
-        # If the data is a list of torch.Tensor objects, create a torch.Tensor
-        return {
-            k: (
-                self._create_graph_batch([v[i] for i in idx_list])
-                if isinstance(v, list)
-                else v[idx_list]
-            )
-            for k, v in data.items()
-        }
-
-    @property
-    def input(self):
-        """
-        Return the input data for the dataset.
-
-        :return: Dictionary containing the input points.
-        :rtype: dict
-        """
-        return {k: v["input"] for k, v in self.conditions_dict.items()}
+class PinaGraphDataset(Dataset):
+    def __init__(self, data_dict, automatic_batching=None):
+        """
+        Initialize the instance by storing the conditions dictionary.
+
+        :param dict conditions_dict: A dictionary mapping condition names to
+            their respective data. Each key represents a condition name, and the
+            corresponding value is a dictionary containing the associated data.
+        """
+        # Store the conditions dictionary
+        self.data = data_dict
+        self.automatic_batching = (
+            automatic_batching if automatic_batching is not None else True
+        )
+
+    def __len__(self):
+        return len(next(iter(self.data.values())))
+
+    def __getitem__(self, idx):
+        """
+        Return the data at the given index in the dataset.
+
+        :param int idx: Index.
+        :return: A dictionary containing the data at the given index.
+        :rtype: dict
+        """
+        if self.automatic_batching:
+            # Return the data at the given index
+            return {
+                field_name: data[idx] for field_name, data in self.data.items()
+            }
+        return idx
+
+    def _getitem_from_list(self, idx_list):
+        """
+        Return data from the dataset given a list of indices.
+
+        :param list[int] idx_list: List of indices.
+        :return: A dictionary containing the data at the given indices.
+        :rtype: dict
+        """
+        return {
+            field_name: [data[i] for i in idx_list]
+            for field_name, data in self.data.items()
+        }