This commit is contained in:
FilippoOlivo
2025-03-11 17:18:48 +01:00
committed by Nicola Demo
parent 99d2e70f4a
commit cbbaa4062f
3 changed files with 193 additions and 106 deletions

View File

@@ -1,5 +1,5 @@
"""
Import data classes
Module for data data module and dataset.
"""
__all__ = ["PinaDataModule", "PinaDataset"]

View File

@@ -16,16 +16,16 @@ from ..collector import Collector
class DummyDataloader:
""" "
Dummy dataloader used when batch size is None. It callects all the data
in self.dataset and returns it when it is called a single batch.
"""
Dataloader used when batch size is ``None``. It returns the entire dataset
in a single batch.
"""
def __init__(self, dataset):
"""
Preprare a dataloader object which will return the entire dataset
in a single batch. Depending on the number of GPUs, the dataset we
have the following cases:
in a single batch. Depending on the number of GPUs, the dataset is
managed as follows:
- **Distributed Environment** (multiple GPUs):
- Divides the dataset across processes using the rank and world
@@ -38,7 +38,7 @@ class DummyDataloader:
:param dataset: The dataset object to be processed.
:type dataset: PinaDataset
.. note:: This data loader is used when the batch size is None.
.. note:: This data loader is used when the batch size is ``None``.
"""
if (
@@ -72,7 +72,9 @@ class DummyDataloader:
class Collator:
"""
Class used to collate the batch
This callable class is used to collate the data points fetched from the
dataset. The collation is performed based on the type of dataset used and
on the batching strategy.
"""
def __init__(
@@ -121,7 +123,13 @@ class Collator:
def _collate_torch_dataloader(self, batch):
"""
Function used to collate the batch
:param list(dict) batch: List of retrieved data.
:return: Dictionary containing the data points fetched from the dataset,
collated.
:rtype: dict
"""
batch_dict = {}
if isinstance(batch, dict):
return batch
@@ -149,15 +157,15 @@ class Collator:
def _collate_tensor_dataset(data_list):
"""
Function used to collate the data when the dataset is a
`PinaTensorDataset`.
:class:`PinaTensorDataset`.
:param data_list: List of `torch.Tensor` or `LabelTensor` to be
collated.
:param data_list: Elements to be collated.
:type data_list: list(torch.Tensor) | list(LabelTensor)
:raises RuntimeError: If the data is not a `torch.Tensor` or a
`LabelTensor`.
:return: Batch of data
:return: Batch of data.
:rtype: dict
:raises RuntimeError: If the data is not a :class:`torch.Tensor` or a
:class:`LabelTensor`.
"""
if isinstance(data_list[0], LabelTensor):
@@ -169,13 +177,15 @@ class Collator:
def _collate_graph_dataset(self, data_list):
"""
Function used to collate the data when the dataset is a
`PinaGraphDataset`.
:class:`PinaGraphDataset`.
:param data_list: List of `Data` or `Graph` to be collated.
:type data_list: list(Data) | list(Graph)
:raises RuntimeError: If the data is not a `Data` or a `Graph`.
:return: Batch of data
:param data_list: Elememts to be collated.
:type data_list: list(torch_geometric.data.Data) | list(Graph)
:return: Batch of data.
:rtype: dict
:raises RuntimeError: If the data is not a
:class:`torch_geometric.data.Data` or a :class:`Graph`.
"""
if isinstance(data_list[0], LabelTensor):
@@ -184,13 +194,18 @@ class Collator:
return torch.cat(data_list)
if isinstance(data_list[0], Data):
return self.dataset.create_batch(data_list)
raise RuntimeError("Data must be Tensors or LabelTensor or pyG Data")
raise RuntimeError(
"Data must be Tensors or LabelTensor or pyG "
"torch_geometric.data.Data"
)
def __call__(self, batch):
"""
Call the function to collate the batch, defined in __init__.
Perform the collation of the data points fetched from the dataset.
The behavoior of the function is set based on the batching strategy
during class initialization.
:param batch: list of indices or list of retrieved data
:param batch: List of retrieved data or sampled indices.
:type batch: list(int) | list(dict)
:return: Dictionary containing the data points fetched from the dataset,
collated.
@@ -202,13 +217,14 @@ class Collator:
class PinaSampler:
"""
Class used to create the sampler instance.
This class is used to create the sampler instance based on the shuffle
parameter and the environment in which the code is running.
"""
def __new__(cls, dataset, shuffle):
"""
Create the sampler instance, according to shuffle and whether the
environment is distributed or not.
Instantiate the sampler based on the environment in which the code is
running.
:param PinaDataset dataset: The dataset to be sampled.
:param bool shuffle: whether to shuffle the dataset or not before
@@ -232,8 +248,9 @@ class PinaSampler:
class PinaDataModule(LightningDataModule):
"""
This class extend LightningDataModule, allowing proper creation and
management of different types of Datasets defined in PINA
This class extends :class:`pytorch_lightning.LightningDataModule`,
allowing proper creation and management of different types of datasets
defined in PINA.
"""
def __init__(
@@ -253,23 +270,31 @@ class PinaDataModule(LightningDataModule):
Initialize the object, creating datasets based on the input problem.
:param AbstractProblem problem: The problem containing the data on which
to train/test the model.
to create the datasets and dataloaders.
:param float train_size: Fraction or number of elements in the training
split.
split. It must be in the range [0, 1].
:param float test_size: Fraction or number of elements in the test
split.
split. It must be in the range [0, 1].
:param float val_size: Fraction or number of elements in the validation
split.
split. It must be in the range [0, 1].
:param batch_size: The batch size used for training. If `None`, the
entire dataset is used per batch.
:type batch_size: int | None
:param bool shuffle: Whether to shuffle the dataset before splitting.
Default True.
:param bool repeat: Whether to repeat the dataset indefinitely.
Default False.
:param automatic_batching: Whether to enable automatic batching.
Default False.
:param int num_workers: Number of worker threads for data loading.
Default 0 (serial loading)
Default 0 (serial loading). For more information, see
https://pytorch.org/docs/stable/data.html#multi-process-data-loading
:param bool pin_memory: Whether to use pinned memory for faster data
transfer to GPU. (Default False)
transfer to GPU. (Default False). For more information, see
https://pytorch.org/docs/stable/data.html#memory-pinning
:raises ValueError: If at least one of the splits is negative.
:raises ValueError: If the sum of the splits is different from 1.
"""
super().__init__()
@@ -278,6 +303,8 @@ class PinaDataModule(LightningDataModule):
self.shuffle = shuffle
self.repeat = repeat
self.automatic_batching = automatic_batching
# If batch size is None, num_workers has no effect
if batch_size is None and num_workers != 0:
warnings.warn(
"Setting num_workers when batch_size is None has no effect on "
@@ -286,6 +313,8 @@ class PinaDataModule(LightningDataModule):
self.num_workers = 0
else:
self.num_workers = num_workers
# If batch size is None, pin_memory has no effect
if batch_size is None and pin_memory:
warnings.warn(
"Setting pin_memory to True has no effect when "
@@ -309,16 +338,22 @@ class PinaDataModule(LightningDataModule):
splits_dict["train"] = train_size
self.train_dataset = None
else:
# Use the super method to create the train dataloader which
# raises NotImplementedError
self.train_dataloader = super().train_dataloader
if test_size > 0:
splits_dict["test"] = test_size
self.test_dataset = None
else:
# Use the super method to create the train dataloader which
# raises NotImplementedError
self.test_dataloader = super().test_dataloader
if val_size > 0:
splits_dict["val"] = val_size
self.val_dataset = None
else:
# Use the super method to create the train dataloader which
# raises NotImplementedError
self.val_dataloader = super().val_dataloader
self.collector_splits = self._create_splits(collector, splits_dict)
@@ -326,7 +361,13 @@ class PinaDataModule(LightningDataModule):
def setup(self, stage=None):
"""
Perform the splitting of the dataset
Create the dataset objects for the given stage.
If the stage is "fit", the training and validation datasets are created.
If the stage is "test", the testing dataset is created.
:param str stage: The stage for which to perform the splitting.
:raises ValueError: If the stage is neither "fit" nor "test".
"""
if stage == "fit" or stage is None:
self.train_dataset = PinaDatasetFactory(
@@ -354,8 +395,18 @@ class PinaDataModule(LightningDataModule):
raise ValueError("stage must be either 'fit' or 'test'.")
@staticmethod
def _split_condition(condition_dict, splits_dict):
len_condition = len(condition_dict["input"])
def _split_condition(single_condition_dict, splits_dict):
"""
Split the condition into different stages.
:param dict single_condition_dict: The condition to be split.
:param dict splits_dict: The dictionary containing the number of
elements in each stage.
:return: A dictionary containing the split condition.
:rtype: dict
"""
len_condition = len(single_condition_dict["input"])
lengths = [
int(len_condition * length) for length in splits_dict.values()
@@ -374,7 +425,7 @@ class PinaDataModule(LightningDataModule):
for stage, stage_len in splits_dict.items():
to_return_dict[stage] = {
k: v[offset : offset + stage_len]
for k, v in condition_dict.items()
for k, v in single_condition_dict.items()
if k != "equation"
# Equations are NEVER dataloaded
}
@@ -386,7 +437,13 @@ class PinaDataModule(LightningDataModule):
def _create_splits(self, collector, splits_dict):
"""
Create the dataset objects putting data
Create the dataset objects putting data in the correct splits.
:param Collector collector: The collector object containing the data.
:param dict splits_dict: The dictionary containing the number of
elements in each stage.
:return: The dictionary containing the dataset objects.
:rtype: dict
"""
# ----------- Auxiliary function ------------
@@ -422,6 +479,15 @@ class PinaDataModule(LightningDataModule):
return dataset_dict
def _create_dataloader(self, split, dataset):
""" "
Create the dataloader for the given split.
:param str split: The split on which to create the dataloader.
:param str dataset: The dataset to be used for the dataloader.
:return: The dataloader for the given split.
:rtype: torch.utils.data.DataLoader
"""
shuffle = self.shuffle if split == "train" else False
# Suppress the warning about num_workers.
# In many cases, especially for PINNs,
@@ -470,6 +536,7 @@ class PinaDataModule(LightningDataModule):
:return: The maximum length of the conditions.
:rtype: dict
"""
max_conditions_lengths = {}
for k, v in self.collector_splits[split].items():
if self.batch_size is None:
@@ -484,7 +551,10 @@ class PinaDataModule(LightningDataModule):
def val_dataloader(self):
"""
Create the validation dataloader
Create the validation dataloader.
:return: The validation dataloader
:rtype: DataLoader
"""
return self._create_dataloader("val", self.val_dataset)
@@ -509,20 +579,17 @@ class PinaDataModule(LightningDataModule):
@staticmethod
def _transfer_batch_to_device_dummy(batch, device, dataloader_idx):
"""
Transfer the batch to the device. This method is called in the
training loop and is used to transfer the batch to the device.
This method is used when the batch size is None: batch has already
been transferred to the device.
Transfer the batch to the device. This method is used when the batch
size is None: batch has already been transferred to the device.
:param list(tuple) batch: list of tuple where the first element of the
tuple is the condition name and the second element is the data.
:param device: device to which the batch is transferred
:type device: torch.device
:param dataloader_idx: index of the dataloader
:type dataloader_idx: int
:param torch.device device: device to which the batch is transferred.
:param int dataloader_idx: index of the dataloader.
:return: The batch transferred to the device.
:rtype: list(tuple)
"""
return batch
def _transfer_batch_to_device(self, batch, device, dataloader_idx):
@@ -531,12 +598,13 @@ class PinaDataModule(LightningDataModule):
training loop and is used to transfer the batch to the device.
:param dict batch: The batch to be transferred to the device.
:param device: The device to which the batch is transferred.
:type device: torch.device
:param torch.device device: The device to which the batch is
transferred.
:param int dataloader_idx: The index of the dataloader.
:return: The batch transferred to the device.
:rtype: list(tuple)
"""
batch = [
(
k,
@@ -552,8 +620,18 @@ class PinaDataModule(LightningDataModule):
@staticmethod
def _check_slit_sizes(train_size, test_size, val_size):
"""
Check if the splits are correct
Check if the splits are correct. The splits sizes must be positive and
the sum of the splits must be 1.
:param float train_size: The size of the training split.
:param float test_size: The size of the testing split.
:param float val_size: The size of the validation split.
:raises ValueError: If at least one of the splits is negative.
:raises ValueError: If the sum of the splits is different
from 1.
"""
if train_size < 0 or test_size < 0 or val_size < 0:
raise ValueError("The splits must be positive")
if abs(train_size + test_size + val_size - 1) > 1e-6:
@@ -567,6 +645,7 @@ class PinaDataModule(LightningDataModule):
:return: The input points for training.
:rtype dict
"""
to_return = {}
if hasattr(self, "train_dataset") and self.train_dataset is not None:
to_return["train"] = self.train_dataset.input

View File

@@ -1,8 +1,8 @@
"""
This module provide basic data management functionalities
Module for the PINA dataset
"""
from abc import abstractmethod
from abc import abstractmethod, ABC
from torch.utils.data import Dataset
from torch_geometric.data import Data
from ..graph import Graph, LabelBatch
@@ -15,9 +15,10 @@ class PinaDatasetFactory:
Depending on the type inside the conditions, it creates a different dataset
object:
- :class:`PinaTensorDataset` for `torch.Tensor`
- :class:`PinaGraphDataset` for `list` of `torch_geometric.data.Data`
objects
- :class:`PinaTensorDataset` for handling :class:`torch.Tensor` and
:class:`LabelTensor` data.
- :class:`PinaGraphDataset` for handling :class:`Graph` and :class:`Data`
data.
"""
def __new__(cls, conditions_dict, **kwargs):
@@ -28,7 +29,8 @@ class PinaDatasetFactory:
:class:`PinaGraphDataset`, otherwise returns a
:class:`PinaTensorDataset`.
:param dict conditions_dict: Dictionary containing the conditions.
:param dict conditions_dict: Dictionary containing all the conditions
to be included in the dataset instance.
:return: A subclass of :class:`PinaDataset`.
:rtype: :class:`PinaTensorDataset` | :class:`PinaGraphDataset`
@@ -50,11 +52,11 @@ class PinaDatasetFactory:
@staticmethod
def _is_graph_dataset(conditions_dict):
"""
Check if a graph is present in the conditions.
Check if a graph is present in the conditions (at least one time).
:param conditions_dict: Dictionary containing the conditions.
:type conditions_dict: dict
:return: True if a graph is present in the conditions, False otherwise
:return: True if a graph is present in the conditions, False otherwise.
:rtype: bool
"""
@@ -68,25 +70,28 @@ class PinaDatasetFactory:
return False
class PinaDataset(Dataset):
class PinaDataset(Dataset, ABC):
"""
Abstract class for the PINA dataset
Abstract class for the PINA dataset. It defines the common interface for
the :class:`PinaTensorDataset` and :class:`PinaGraphDataset` classes.
"""
def __init__(
self, conditions_dict, max_conditions_lengths, automatic_batching
):
"""
Initialize the :class:`PinaDataset`.
Initialize a :class:`PinaDataset` instance by storing the provided
conditions dictionary, the maximum number of conditions to consider,
and the automatic batching flag.
Stores the conditions dictionary, the maximum number of conditions to
consider, and the automatic batching flag.
:param dict conditions_dict: Dictionary containing the conditions.
:param dict max_conditions_lengths: Maximum number of data points to
consider in a single batch for each condition.
:param bool automatic_batching: Whether PyTorch automatic batching is
enabled in :class:`PinaDataModule`.
:param conditions_dict: Dictionary containing the conditions.
:type conditions_dict: dict
:param max_conditions_lengths: Specifies the maximum number of data
points to include in a single batch for each condition.
:type max_conditions_lengths: dict
:param automatic_batching: Indicates whether PyTorch automatic batching
is enabled in :class:`PinaDataModule`.
:type automatic_batching: bool
"""
# Store the conditions dictionary
@@ -107,9 +112,9 @@ class PinaDataset(Dataset):
def _get_max_len(self):
"""
Returns the length of the longest condition in the dataset
Returns the length of the longest condition in the dataset.
:return: Length of the longest condition in the dataset
:return: Length of the longest condition in the dataset.
:rtype: int
"""
@@ -129,9 +134,9 @@ class PinaDataset(Dataset):
Return the index itself. This is used when automatic batching is
disabled to postpone the data retrieval to the dataloader.
:param idx: Index
:param idx: Index.
:type idx: int
:return: Index
:return: Index.
:rtype: int
"""
@@ -143,8 +148,8 @@ class PinaDataset(Dataset):
Return the data at the given index in the dataset. This is used when
automatic batching is enabled.
:param int idx: Index
:return: A dictionary containing the data at the given index
:param int idx: Index.
:return: A dictionary containing the data at the given index.
:rtype: dict
"""
@@ -156,23 +161,25 @@ class PinaDataset(Dataset):
def get_all_data(self):
"""
Return all data in the dataset
Return all data in the dataset.
:return: All data in the dataset
:return: A dictionary containing all the data in the dataset.
:rtype: dict
"""
index = list(range(len(self)))
return self.fetch_from_idx_list(index)
def fetch_from_idx_list(self, idx):
"""
Return data from the dataset given a list of indices
Return data from the dataset given a list of indices.
:param idx: List of indices
:param idx: List of indices.
:type idx: list
:return: Data from the dataset
:return: A dictionary containing the data at the given indices.
:rtype: dict
"""
to_return_dict = {}
for condition, data in self.conditions_dict.items():
# Get the indices for the current condition
@@ -190,30 +197,27 @@ class PinaDataset(Dataset):
@abstractmethod
def _retrive_data(self, data, idx_list):
"""
Retrieve data from the dataset given a list of indices
:param dict data: Dictionary containing the data
:param list idx_list: List of indices to retrieve
:return: Dictionary containing the data at the given indices
:rtype: dict
Abstract method to retrieve data from the dataset given a list of
indices.
"""
class PinaTensorDataset(PinaDataset):
"""
Class for the PINA dataset with torch.Tensor data
Dataset class for the PINA dataset with :class:`torch.Tensor` and
:class:`LabelTensor` data.
"""
# Override _retrive_data method for torch.Tensor data
def _retrive_data(self, data, idx_list):
"""
Retrieve data from the dataset given a list of indices
Retrieve data from the dataset given a list of indices.
:param data: Dictionary containing the data
(only torch.Tensor/LableTensor)
(only torch.Tensor/LableTensor).
:type data: dict
:param list(int) idx_list: indices to retrieve
:return: Dictionary containing the data at the given indices
:param list(int) idx_list: indices to retrieve.
:return: Dictionary containing the data at the given indices.
:rtype: dict
"""
@@ -222,9 +226,9 @@ class PinaTensorDataset(PinaDataset):
@property
def input(self):
"""
Method to return all input points from the dataset.
Return the input data for the dataset.
:return: Dictionary containing the input points
:return: Dictionary containing the input points.
:rtype: dict
"""
return {k: v["input"] for k, v in self.conditions_dict.items()}
@@ -232,15 +236,17 @@ class PinaTensorDataset(PinaDataset):
class PinaGraphDataset(PinaDataset):
"""
Class for the PINA dataset with torch_geometric.data.Data data
Dataset class for the PINA dataset with :class:`torch_geometric.data.Data`
and :class:`Graph` data.
"""
def _create_graph_batch(self, data):
"""
Create a LabelBatch object from a list of Data objects.
Create a LabelBatch object from a list of
:class:`torch_geometric.data.Data` objects.
:param data: List of Data or Graph objects
:type data: list(Data) | list(Graph)
:param data: List of items to collate in a single batch.
:type data: list(torch_geometric.data.Data) | list(Graph)
:return: LabelBatch object all the graph collated in a single batch
disconnected graphs.
:rtype: LabelBatch
@@ -255,7 +261,7 @@ class PinaGraphDataset(PinaDataset):
:param data: torch.Tensor object of shape (N, ...) where N is the
number of data points.
:type data: torch.Tensor | LabelTensor
:return: reshaped torch.Tensor or LabelTensor object
:return: reshaped torch.Tensor or LabelTensor object.
:rtype: torch.Tensor | LabelTensor
"""
out = data.reshape(-1, *data.shape[2:])
@@ -263,12 +269,13 @@ class PinaGraphDataset(PinaDataset):
def create_batch(self, data):
"""
Create a Batch object from a list of Data objects.
Create a Batch object from a list of :class:`torch_geometric.data.Data`
objects.
:param data: List of Data objects
:param data: List of items to collate in a single batch.
:type data: list
:return: Batch object
:rtype: Batch or PinaBatch
:return: Batch object.
:rtype: Batch | PinaBatch
"""
if isinstance(data[0], Data):
@@ -278,13 +285,14 @@ class PinaGraphDataset(PinaDataset):
# Override _retrive_data method for graph handling
def _retrive_data(self, data, idx_list):
"""
Retrieve data from the dataset given a list of indices
Retrieve data from the dataset given a list of indices.
:param dict data: dictionary containing the data
:param list idx_list: list of indices to retrieve
:return: dictionary containing the data at the given indices
:param dict data: Dictionary containing the data.
:param list idx_list: List of indices to retrieve.
:return: Dictionary containing the data at the given indices.
:rtype: dict
"""
# Return the data from the current condition
# If the data is a list of Data objects, create a Batch object
# If the data is a list of torch.Tensor objects, create a torch.Tensor