Doc data
This commit is contained in:
committed by
Nicola Demo
parent
99d2e70f4a
commit
cbbaa4062f
@@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Import data classes
|
Module for data data module and dataset.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__all__ = ["PinaDataModule", "PinaDataset"]
|
__all__ = ["PinaDataModule", "PinaDataset"]
|
||||||
|
|||||||
@@ -16,16 +16,16 @@ from ..collector import Collector
|
|||||||
|
|
||||||
|
|
||||||
class DummyDataloader:
|
class DummyDataloader:
|
||||||
""" "
|
"""
|
||||||
Dummy dataloader used when batch size is None. It callects all the data
|
Dataloader used when batch size is ``None``. It returns the entire dataset
|
||||||
in self.dataset and returns it when it is called a single batch.
|
in a single batch.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dataset):
|
def __init__(self, dataset):
|
||||||
"""
|
"""
|
||||||
Preprare a dataloader object which will return the entire dataset
|
Preprare a dataloader object which will return the entire dataset
|
||||||
in a single batch. Depending on the number of GPUs, the dataset we
|
in a single batch. Depending on the number of GPUs, the dataset is
|
||||||
have the following cases:
|
managed as follows:
|
||||||
|
|
||||||
- **Distributed Environment** (multiple GPUs):
|
- **Distributed Environment** (multiple GPUs):
|
||||||
- Divides the dataset across processes using the rank and world
|
- Divides the dataset across processes using the rank and world
|
||||||
@@ -38,7 +38,7 @@ class DummyDataloader:
|
|||||||
:param dataset: The dataset object to be processed.
|
:param dataset: The dataset object to be processed.
|
||||||
:type dataset: PinaDataset
|
:type dataset: PinaDataset
|
||||||
|
|
||||||
.. note:: This data loader is used when the batch size is None.
|
.. note:: This data loader is used when the batch size is ``None``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@@ -72,7 +72,9 @@ class DummyDataloader:
|
|||||||
|
|
||||||
class Collator:
|
class Collator:
|
||||||
"""
|
"""
|
||||||
Class used to collate the batch
|
This callable class is used to collate the data points fetched from the
|
||||||
|
dataset. The collation is performed based on the type of dataset used and
|
||||||
|
on the batching strategy.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -121,7 +123,13 @@ class Collator:
|
|||||||
def _collate_torch_dataloader(self, batch):
|
def _collate_torch_dataloader(self, batch):
|
||||||
"""
|
"""
|
||||||
Function used to collate the batch
|
Function used to collate the batch
|
||||||
|
|
||||||
|
:param list(dict) batch: List of retrieved data.
|
||||||
|
:return: Dictionary containing the data points fetched from the dataset,
|
||||||
|
collated.
|
||||||
|
:rtype: dict
|
||||||
"""
|
"""
|
||||||
|
|
||||||
batch_dict = {}
|
batch_dict = {}
|
||||||
if isinstance(batch, dict):
|
if isinstance(batch, dict):
|
||||||
return batch
|
return batch
|
||||||
@@ -149,15 +157,15 @@ class Collator:
|
|||||||
def _collate_tensor_dataset(data_list):
|
def _collate_tensor_dataset(data_list):
|
||||||
"""
|
"""
|
||||||
Function used to collate the data when the dataset is a
|
Function used to collate the data when the dataset is a
|
||||||
`PinaTensorDataset`.
|
:class:`PinaTensorDataset`.
|
||||||
|
|
||||||
:param data_list: List of `torch.Tensor` or `LabelTensor` to be
|
:param data_list: Elements to be collated.
|
||||||
collated.
|
|
||||||
:type data_list: list(torch.Tensor) | list(LabelTensor)
|
:type data_list: list(torch.Tensor) | list(LabelTensor)
|
||||||
:raises RuntimeError: If the data is not a `torch.Tensor` or a
|
:return: Batch of data.
|
||||||
`LabelTensor`.
|
|
||||||
:return: Batch of data
|
|
||||||
:rtype: dict
|
:rtype: dict
|
||||||
|
|
||||||
|
:raises RuntimeError: If the data is not a :class:`torch.Tensor` or a
|
||||||
|
:class:`LabelTensor`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(data_list[0], LabelTensor):
|
if isinstance(data_list[0], LabelTensor):
|
||||||
@@ -169,13 +177,15 @@ class Collator:
|
|||||||
def _collate_graph_dataset(self, data_list):
|
def _collate_graph_dataset(self, data_list):
|
||||||
"""
|
"""
|
||||||
Function used to collate the data when the dataset is a
|
Function used to collate the data when the dataset is a
|
||||||
`PinaGraphDataset`.
|
:class:`PinaGraphDataset`.
|
||||||
|
|
||||||
:param data_list: List of `Data` or `Graph` to be collated.
|
:param data_list: Elememts to be collated.
|
||||||
:type data_list: list(Data) | list(Graph)
|
:type data_list: list(torch_geometric.data.Data) | list(Graph)
|
||||||
:raises RuntimeError: If the data is not a `Data` or a `Graph`.
|
:return: Batch of data.
|
||||||
:return: Batch of data
|
|
||||||
:rtype: dict
|
:rtype: dict
|
||||||
|
|
||||||
|
:raises RuntimeError: If the data is not a
|
||||||
|
:class:`torch_geometric.data.Data` or a :class:`Graph`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(data_list[0], LabelTensor):
|
if isinstance(data_list[0], LabelTensor):
|
||||||
@@ -184,13 +194,18 @@ class Collator:
|
|||||||
return torch.cat(data_list)
|
return torch.cat(data_list)
|
||||||
if isinstance(data_list[0], Data):
|
if isinstance(data_list[0], Data):
|
||||||
return self.dataset.create_batch(data_list)
|
return self.dataset.create_batch(data_list)
|
||||||
raise RuntimeError("Data must be Tensors or LabelTensor or pyG Data")
|
raise RuntimeError(
|
||||||
|
"Data must be Tensors or LabelTensor or pyG "
|
||||||
|
"torch_geometric.data.Data"
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, batch):
|
def __call__(self, batch):
|
||||||
"""
|
"""
|
||||||
Call the function to collate the batch, defined in __init__.
|
Perform the collation of the data points fetched from the dataset.
|
||||||
|
The behavoior of the function is set based on the batching strategy
|
||||||
|
during class initialization.
|
||||||
|
|
||||||
:param batch: list of indices or list of retrieved data
|
:param batch: List of retrieved data or sampled indices.
|
||||||
:type batch: list(int) | list(dict)
|
:type batch: list(int) | list(dict)
|
||||||
:return: Dictionary containing the data points fetched from the dataset,
|
:return: Dictionary containing the data points fetched from the dataset,
|
||||||
collated.
|
collated.
|
||||||
@@ -202,13 +217,14 @@ class Collator:
|
|||||||
|
|
||||||
class PinaSampler:
|
class PinaSampler:
|
||||||
"""
|
"""
|
||||||
Class used to create the sampler instance.
|
This class is used to create the sampler instance based on the shuffle
|
||||||
|
parameter and the environment in which the code is running.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(cls, dataset, shuffle):
|
def __new__(cls, dataset, shuffle):
|
||||||
"""
|
"""
|
||||||
Create the sampler instance, according to shuffle and whether the
|
Instantiate the sampler based on the environment in which the code is
|
||||||
environment is distributed or not.
|
running.
|
||||||
|
|
||||||
:param PinaDataset dataset: The dataset to be sampled.
|
:param PinaDataset dataset: The dataset to be sampled.
|
||||||
:param bool shuffle: whether to shuffle the dataset or not before
|
:param bool shuffle: whether to shuffle the dataset or not before
|
||||||
@@ -232,8 +248,9 @@ class PinaSampler:
|
|||||||
|
|
||||||
class PinaDataModule(LightningDataModule):
|
class PinaDataModule(LightningDataModule):
|
||||||
"""
|
"""
|
||||||
This class extend LightningDataModule, allowing proper creation and
|
This class extends :class:`pytorch_lightning.LightningDataModule`,
|
||||||
management of different types of Datasets defined in PINA
|
allowing proper creation and management of different types of datasets
|
||||||
|
defined in PINA.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -253,23 +270,31 @@ class PinaDataModule(LightningDataModule):
|
|||||||
Initialize the object, creating datasets based on the input problem.
|
Initialize the object, creating datasets based on the input problem.
|
||||||
|
|
||||||
:param AbstractProblem problem: The problem containing the data on which
|
:param AbstractProblem problem: The problem containing the data on which
|
||||||
to train/test the model.
|
to create the datasets and dataloaders.
|
||||||
:param float train_size: Fraction or number of elements in the training
|
:param float train_size: Fraction or number of elements in the training
|
||||||
split.
|
split. It must be in the range [0, 1].
|
||||||
:param float test_size: Fraction or number of elements in the test
|
:param float test_size: Fraction or number of elements in the test
|
||||||
split.
|
split. It must be in the range [0, 1].
|
||||||
:param float val_size: Fraction or number of elements in the validation
|
:param float val_size: Fraction or number of elements in the validation
|
||||||
split.
|
split. It must be in the range [0, 1].
|
||||||
:param batch_size: The batch size used for training. If `None`, the
|
:param batch_size: The batch size used for training. If `None`, the
|
||||||
entire dataset is used per batch.
|
entire dataset is used per batch.
|
||||||
:type batch_size: int | None
|
:type batch_size: int | None
|
||||||
:param bool shuffle: Whether to shuffle the dataset before splitting.
|
:param bool shuffle: Whether to shuffle the dataset before splitting.
|
||||||
|
Default True.
|
||||||
:param bool repeat: Whether to repeat the dataset indefinitely.
|
:param bool repeat: Whether to repeat the dataset indefinitely.
|
||||||
|
Default False.
|
||||||
:param automatic_batching: Whether to enable automatic batching.
|
:param automatic_batching: Whether to enable automatic batching.
|
||||||
|
Default False.
|
||||||
:param int num_workers: Number of worker threads for data loading.
|
:param int num_workers: Number of worker threads for data loading.
|
||||||
Default 0 (serial loading)
|
Default 0 (serial loading). For more information, see
|
||||||
|
https://pytorch.org/docs/stable/data.html#multi-process-data-loading
|
||||||
:param bool pin_memory: Whether to use pinned memory for faster data
|
:param bool pin_memory: Whether to use pinned memory for faster data
|
||||||
transfer to GPU. (Default False)
|
transfer to GPU. (Default False). For more information, see
|
||||||
|
https://pytorch.org/docs/stable/data.html#memory-pinning
|
||||||
|
|
||||||
|
:raises ValueError: If at least one of the splits is negative.
|
||||||
|
:raises ValueError: If the sum of the splits is different from 1.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -278,6 +303,8 @@ class PinaDataModule(LightningDataModule):
|
|||||||
self.shuffle = shuffle
|
self.shuffle = shuffle
|
||||||
self.repeat = repeat
|
self.repeat = repeat
|
||||||
self.automatic_batching = automatic_batching
|
self.automatic_batching = automatic_batching
|
||||||
|
|
||||||
|
# If batch size is None, num_workers has no effect
|
||||||
if batch_size is None and num_workers != 0:
|
if batch_size is None and num_workers != 0:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Setting num_workers when batch_size is None has no effect on "
|
"Setting num_workers when batch_size is None has no effect on "
|
||||||
@@ -286,6 +313,8 @@ class PinaDataModule(LightningDataModule):
|
|||||||
self.num_workers = 0
|
self.num_workers = 0
|
||||||
else:
|
else:
|
||||||
self.num_workers = num_workers
|
self.num_workers = num_workers
|
||||||
|
|
||||||
|
# If batch size is None, pin_memory has no effect
|
||||||
if batch_size is None and pin_memory:
|
if batch_size is None and pin_memory:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Setting pin_memory to True has no effect when "
|
"Setting pin_memory to True has no effect when "
|
||||||
@@ -309,16 +338,22 @@ class PinaDataModule(LightningDataModule):
|
|||||||
splits_dict["train"] = train_size
|
splits_dict["train"] = train_size
|
||||||
self.train_dataset = None
|
self.train_dataset = None
|
||||||
else:
|
else:
|
||||||
|
# Use the super method to create the train dataloader which
|
||||||
|
# raises NotImplementedError
|
||||||
self.train_dataloader = super().train_dataloader
|
self.train_dataloader = super().train_dataloader
|
||||||
if test_size > 0:
|
if test_size > 0:
|
||||||
splits_dict["test"] = test_size
|
splits_dict["test"] = test_size
|
||||||
self.test_dataset = None
|
self.test_dataset = None
|
||||||
else:
|
else:
|
||||||
|
# Use the super method to create the train dataloader which
|
||||||
|
# raises NotImplementedError
|
||||||
self.test_dataloader = super().test_dataloader
|
self.test_dataloader = super().test_dataloader
|
||||||
if val_size > 0:
|
if val_size > 0:
|
||||||
splits_dict["val"] = val_size
|
splits_dict["val"] = val_size
|
||||||
self.val_dataset = None
|
self.val_dataset = None
|
||||||
else:
|
else:
|
||||||
|
# Use the super method to create the train dataloader which
|
||||||
|
# raises NotImplementedError
|
||||||
self.val_dataloader = super().val_dataloader
|
self.val_dataloader = super().val_dataloader
|
||||||
|
|
||||||
self.collector_splits = self._create_splits(collector, splits_dict)
|
self.collector_splits = self._create_splits(collector, splits_dict)
|
||||||
@@ -326,7 +361,13 @@ class PinaDataModule(LightningDataModule):
|
|||||||
|
|
||||||
def setup(self, stage=None):
|
def setup(self, stage=None):
|
||||||
"""
|
"""
|
||||||
Perform the splitting of the dataset
|
Create the dataset objects for the given stage.
|
||||||
|
If the stage is "fit", the training and validation datasets are created.
|
||||||
|
If the stage is "test", the testing dataset is created.
|
||||||
|
|
||||||
|
:param str stage: The stage for which to perform the splitting.
|
||||||
|
|
||||||
|
:raises ValueError: If the stage is neither "fit" nor "test".
|
||||||
"""
|
"""
|
||||||
if stage == "fit" or stage is None:
|
if stage == "fit" or stage is None:
|
||||||
self.train_dataset = PinaDatasetFactory(
|
self.train_dataset = PinaDatasetFactory(
|
||||||
@@ -354,8 +395,18 @@ class PinaDataModule(LightningDataModule):
|
|||||||
raise ValueError("stage must be either 'fit' or 'test'.")
|
raise ValueError("stage must be either 'fit' or 'test'.")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _split_condition(condition_dict, splits_dict):
|
def _split_condition(single_condition_dict, splits_dict):
|
||||||
len_condition = len(condition_dict["input"])
|
"""
|
||||||
|
Split the condition into different stages.
|
||||||
|
|
||||||
|
:param dict single_condition_dict: The condition to be split.
|
||||||
|
:param dict splits_dict: The dictionary containing the number of
|
||||||
|
elements in each stage.
|
||||||
|
:return: A dictionary containing the split condition.
|
||||||
|
:rtype: dict
|
||||||
|
"""
|
||||||
|
|
||||||
|
len_condition = len(single_condition_dict["input"])
|
||||||
|
|
||||||
lengths = [
|
lengths = [
|
||||||
int(len_condition * length) for length in splits_dict.values()
|
int(len_condition * length) for length in splits_dict.values()
|
||||||
@@ -374,7 +425,7 @@ class PinaDataModule(LightningDataModule):
|
|||||||
for stage, stage_len in splits_dict.items():
|
for stage, stage_len in splits_dict.items():
|
||||||
to_return_dict[stage] = {
|
to_return_dict[stage] = {
|
||||||
k: v[offset : offset + stage_len]
|
k: v[offset : offset + stage_len]
|
||||||
for k, v in condition_dict.items()
|
for k, v in single_condition_dict.items()
|
||||||
if k != "equation"
|
if k != "equation"
|
||||||
# Equations are NEVER dataloaded
|
# Equations are NEVER dataloaded
|
||||||
}
|
}
|
||||||
@@ -386,7 +437,13 @@ class PinaDataModule(LightningDataModule):
|
|||||||
|
|
||||||
def _create_splits(self, collector, splits_dict):
|
def _create_splits(self, collector, splits_dict):
|
||||||
"""
|
"""
|
||||||
Create the dataset objects putting data
|
Create the dataset objects putting data in the correct splits.
|
||||||
|
|
||||||
|
:param Collector collector: The collector object containing the data.
|
||||||
|
:param dict splits_dict: The dictionary containing the number of
|
||||||
|
elements in each stage.
|
||||||
|
:return: The dictionary containing the dataset objects.
|
||||||
|
:rtype: dict
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# ----------- Auxiliary function ------------
|
# ----------- Auxiliary function ------------
|
||||||
@@ -422,6 +479,15 @@ class PinaDataModule(LightningDataModule):
|
|||||||
return dataset_dict
|
return dataset_dict
|
||||||
|
|
||||||
def _create_dataloader(self, split, dataset):
|
def _create_dataloader(self, split, dataset):
|
||||||
|
""" "
|
||||||
|
Create the dataloader for the given split.
|
||||||
|
|
||||||
|
:param str split: The split on which to create the dataloader.
|
||||||
|
:param str dataset: The dataset to be used for the dataloader.
|
||||||
|
:return: The dataloader for the given split.
|
||||||
|
:rtype: torch.utils.data.DataLoader
|
||||||
|
"""
|
||||||
|
|
||||||
shuffle = self.shuffle if split == "train" else False
|
shuffle = self.shuffle if split == "train" else False
|
||||||
# Suppress the warning about num_workers.
|
# Suppress the warning about num_workers.
|
||||||
# In many cases, especially for PINNs,
|
# In many cases, especially for PINNs,
|
||||||
@@ -470,6 +536,7 @@ class PinaDataModule(LightningDataModule):
|
|||||||
:return: The maximum length of the conditions.
|
:return: The maximum length of the conditions.
|
||||||
:rtype: dict
|
:rtype: dict
|
||||||
"""
|
"""
|
||||||
|
|
||||||
max_conditions_lengths = {}
|
max_conditions_lengths = {}
|
||||||
for k, v in self.collector_splits[split].items():
|
for k, v in self.collector_splits[split].items():
|
||||||
if self.batch_size is None:
|
if self.batch_size is None:
|
||||||
@@ -484,7 +551,10 @@ class PinaDataModule(LightningDataModule):
|
|||||||
|
|
||||||
def val_dataloader(self):
|
def val_dataloader(self):
|
||||||
"""
|
"""
|
||||||
Create the validation dataloader
|
Create the validation dataloader.
|
||||||
|
|
||||||
|
:return: The validation dataloader
|
||||||
|
:rtype: DataLoader
|
||||||
"""
|
"""
|
||||||
return self._create_dataloader("val", self.val_dataset)
|
return self._create_dataloader("val", self.val_dataset)
|
||||||
|
|
||||||
@@ -509,20 +579,17 @@ class PinaDataModule(LightningDataModule):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _transfer_batch_to_device_dummy(batch, device, dataloader_idx):
|
def _transfer_batch_to_device_dummy(batch, device, dataloader_idx):
|
||||||
"""
|
"""
|
||||||
Transfer the batch to the device. This method is called in the
|
Transfer the batch to the device. This method is used when the batch
|
||||||
training loop and is used to transfer the batch to the device.
|
size is None: batch has already been transferred to the device.
|
||||||
This method is used when the batch size is None: batch has already
|
|
||||||
been transferred to the device.
|
|
||||||
|
|
||||||
:param list(tuple) batch: list of tuple where the first element of the
|
:param list(tuple) batch: list of tuple where the first element of the
|
||||||
tuple is the condition name and the second element is the data.
|
tuple is the condition name and the second element is the data.
|
||||||
:param device: device to which the batch is transferred
|
:param torch.device device: device to which the batch is transferred.
|
||||||
:type device: torch.device
|
:param int dataloader_idx: index of the dataloader.
|
||||||
:param dataloader_idx: index of the dataloader
|
|
||||||
:type dataloader_idx: int
|
|
||||||
:return: The batch transferred to the device.
|
:return: The batch transferred to the device.
|
||||||
:rtype: list(tuple)
|
:rtype: list(tuple)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def _transfer_batch_to_device(self, batch, device, dataloader_idx):
|
def _transfer_batch_to_device(self, batch, device, dataloader_idx):
|
||||||
@@ -531,12 +598,13 @@ class PinaDataModule(LightningDataModule):
|
|||||||
training loop and is used to transfer the batch to the device.
|
training loop and is used to transfer the batch to the device.
|
||||||
|
|
||||||
:param dict batch: The batch to be transferred to the device.
|
:param dict batch: The batch to be transferred to the device.
|
||||||
:param device: The device to which the batch is transferred.
|
:param torch.device device: The device to which the batch is
|
||||||
:type device: torch.device
|
transferred.
|
||||||
:param int dataloader_idx: The index of the dataloader.
|
:param int dataloader_idx: The index of the dataloader.
|
||||||
:return: The batch transferred to the device.
|
:return: The batch transferred to the device.
|
||||||
:rtype: list(tuple)
|
:rtype: list(tuple)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
batch = [
|
batch = [
|
||||||
(
|
(
|
||||||
k,
|
k,
|
||||||
@@ -552,8 +620,18 @@ class PinaDataModule(LightningDataModule):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _check_slit_sizes(train_size, test_size, val_size):
|
def _check_slit_sizes(train_size, test_size, val_size):
|
||||||
"""
|
"""
|
||||||
Check if the splits are correct
|
Check if the splits are correct. The splits sizes must be positive and
|
||||||
|
the sum of the splits must be 1.
|
||||||
|
|
||||||
|
:param float train_size: The size of the training split.
|
||||||
|
:param float test_size: The size of the testing split.
|
||||||
|
:param float val_size: The size of the validation split.
|
||||||
|
|
||||||
|
:raises ValueError: If at least one of the splits is negative.
|
||||||
|
:raises ValueError: If the sum of the splits is different
|
||||||
|
from 1.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if train_size < 0 or test_size < 0 or val_size < 0:
|
if train_size < 0 or test_size < 0 or val_size < 0:
|
||||||
raise ValueError("The splits must be positive")
|
raise ValueError("The splits must be positive")
|
||||||
if abs(train_size + test_size + val_size - 1) > 1e-6:
|
if abs(train_size + test_size + val_size - 1) > 1e-6:
|
||||||
@@ -567,6 +645,7 @@ class PinaDataModule(LightningDataModule):
|
|||||||
:return: The input points for training.
|
:return: The input points for training.
|
||||||
:rtype dict
|
:rtype dict
|
||||||
"""
|
"""
|
||||||
|
|
||||||
to_return = {}
|
to_return = {}
|
||||||
if hasattr(self, "train_dataset") and self.train_dataset is not None:
|
if hasattr(self, "train_dataset") and self.train_dataset is not None:
|
||||||
to_return["train"] = self.train_dataset.input
|
to_return["train"] = self.train_dataset.input
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
"""
|
"""
|
||||||
This module provide basic data management functionalities
|
Module for the PINA dataset
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod, ABC
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from torch_geometric.data import Data
|
from torch_geometric.data import Data
|
||||||
from ..graph import Graph, LabelBatch
|
from ..graph import Graph, LabelBatch
|
||||||
@@ -15,9 +15,10 @@ class PinaDatasetFactory:
|
|||||||
Depending on the type inside the conditions, it creates a different dataset
|
Depending on the type inside the conditions, it creates a different dataset
|
||||||
object:
|
object:
|
||||||
|
|
||||||
- :class:`PinaTensorDataset` for `torch.Tensor`
|
- :class:`PinaTensorDataset` for handling :class:`torch.Tensor` and
|
||||||
- :class:`PinaGraphDataset` for `list` of `torch_geometric.data.Data`
|
:class:`LabelTensor` data.
|
||||||
objects
|
- :class:`PinaGraphDataset` for handling :class:`Graph` and :class:`Data`
|
||||||
|
data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(cls, conditions_dict, **kwargs):
|
def __new__(cls, conditions_dict, **kwargs):
|
||||||
@@ -28,7 +29,8 @@ class PinaDatasetFactory:
|
|||||||
:class:`PinaGraphDataset`, otherwise returns a
|
:class:`PinaGraphDataset`, otherwise returns a
|
||||||
:class:`PinaTensorDataset`.
|
:class:`PinaTensorDataset`.
|
||||||
|
|
||||||
:param dict conditions_dict: Dictionary containing the conditions.
|
:param dict conditions_dict: Dictionary containing all the conditions
|
||||||
|
to be included in the dataset instance.
|
||||||
:return: A subclass of :class:`PinaDataset`.
|
:return: A subclass of :class:`PinaDataset`.
|
||||||
:rtype: :class:`PinaTensorDataset` | :class:`PinaGraphDataset`
|
:rtype: :class:`PinaTensorDataset` | :class:`PinaGraphDataset`
|
||||||
|
|
||||||
@@ -50,11 +52,11 @@ class PinaDatasetFactory:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _is_graph_dataset(conditions_dict):
|
def _is_graph_dataset(conditions_dict):
|
||||||
"""
|
"""
|
||||||
Check if a graph is present in the conditions.
|
Check if a graph is present in the conditions (at least one time).
|
||||||
|
|
||||||
:param conditions_dict: Dictionary containing the conditions.
|
:param conditions_dict: Dictionary containing the conditions.
|
||||||
:type conditions_dict: dict
|
:type conditions_dict: dict
|
||||||
:return: True if a graph is present in the conditions, False otherwise
|
:return: True if a graph is present in the conditions, False otherwise.
|
||||||
:rtype: bool
|
:rtype: bool
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -68,25 +70,28 @@ class PinaDatasetFactory:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class PinaDataset(Dataset):
|
class PinaDataset(Dataset, ABC):
|
||||||
"""
|
"""
|
||||||
Abstract class for the PINA dataset
|
Abstract class for the PINA dataset. It defines the common interface for
|
||||||
|
the :class:`PinaTensorDataset` and :class:`PinaGraphDataset` classes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, conditions_dict, max_conditions_lengths, automatic_batching
|
self, conditions_dict, max_conditions_lengths, automatic_batching
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the :class:`PinaDataset`.
|
Initialize a :class:`PinaDataset` instance by storing the provided
|
||||||
|
conditions dictionary, the maximum number of conditions to consider,
|
||||||
|
and the automatic batching flag.
|
||||||
|
|
||||||
Stores the conditions dictionary, the maximum number of conditions to
|
:param conditions_dict: Dictionary containing the conditions.
|
||||||
consider, and the automatic batching flag.
|
:type conditions_dict: dict
|
||||||
|
:param max_conditions_lengths: Specifies the maximum number of data
|
||||||
:param dict conditions_dict: Dictionary containing the conditions.
|
points to include in a single batch for each condition.
|
||||||
:param dict max_conditions_lengths: Maximum number of data points to
|
:type max_conditions_lengths: dict
|
||||||
consider in a single batch for each condition.
|
:param automatic_batching: Indicates whether PyTorch automatic batching
|
||||||
:param bool automatic_batching: Whether PyTorch automatic batching is
|
is enabled in :class:`PinaDataModule`.
|
||||||
enabled in :class:`PinaDataModule`.
|
:type automatic_batching: bool
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Store the conditions dictionary
|
# Store the conditions dictionary
|
||||||
@@ -107,9 +112,9 @@ class PinaDataset(Dataset):
|
|||||||
|
|
||||||
def _get_max_len(self):
|
def _get_max_len(self):
|
||||||
"""
|
"""
|
||||||
Returns the length of the longest condition in the dataset
|
Returns the length of the longest condition in the dataset.
|
||||||
|
|
||||||
:return: Length of the longest condition in the dataset
|
:return: Length of the longest condition in the dataset.
|
||||||
:rtype: int
|
:rtype: int
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -129,9 +134,9 @@ class PinaDataset(Dataset):
|
|||||||
Return the index itself. This is used when automatic batching is
|
Return the index itself. This is used when automatic batching is
|
||||||
disabled to postpone the data retrieval to the dataloader.
|
disabled to postpone the data retrieval to the dataloader.
|
||||||
|
|
||||||
:param idx: Index
|
:param idx: Index.
|
||||||
:type idx: int
|
:type idx: int
|
||||||
:return: Index
|
:return: Index.
|
||||||
:rtype: int
|
:rtype: int
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -143,8 +148,8 @@ class PinaDataset(Dataset):
|
|||||||
Return the data at the given index in the dataset. This is used when
|
Return the data at the given index in the dataset. This is used when
|
||||||
automatic batching is enabled.
|
automatic batching is enabled.
|
||||||
|
|
||||||
:param int idx: Index
|
:param int idx: Index.
|
||||||
:return: A dictionary containing the data at the given index
|
:return: A dictionary containing the data at the given index.
|
||||||
:rtype: dict
|
:rtype: dict
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -156,23 +161,25 @@ class PinaDataset(Dataset):
|
|||||||
|
|
||||||
def get_all_data(self):
|
def get_all_data(self):
|
||||||
"""
|
"""
|
||||||
Return all data in the dataset
|
Return all data in the dataset.
|
||||||
|
|
||||||
:return: All data in the dataset
|
:return: A dictionary containing all the data in the dataset.
|
||||||
:rtype: dict
|
:rtype: dict
|
||||||
"""
|
"""
|
||||||
|
|
||||||
index = list(range(len(self)))
|
index = list(range(len(self)))
|
||||||
return self.fetch_from_idx_list(index)
|
return self.fetch_from_idx_list(index)
|
||||||
|
|
||||||
def fetch_from_idx_list(self, idx):
|
def fetch_from_idx_list(self, idx):
|
||||||
"""
|
"""
|
||||||
Return data from the dataset given a list of indices
|
Return data from the dataset given a list of indices.
|
||||||
|
|
||||||
:param idx: List of indices
|
:param idx: List of indices.
|
||||||
:type idx: list
|
:type idx: list
|
||||||
:return: Data from the dataset
|
:return: A dictionary containing the data at the given indices.
|
||||||
:rtype: dict
|
:rtype: dict
|
||||||
"""
|
"""
|
||||||
|
|
||||||
to_return_dict = {}
|
to_return_dict = {}
|
||||||
for condition, data in self.conditions_dict.items():
|
for condition, data in self.conditions_dict.items():
|
||||||
# Get the indices for the current condition
|
# Get the indices for the current condition
|
||||||
@@ -190,30 +197,27 @@ class PinaDataset(Dataset):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _retrive_data(self, data, idx_list):
|
def _retrive_data(self, data, idx_list):
|
||||||
"""
|
"""
|
||||||
Retrieve data from the dataset given a list of indices
|
Abstract method to retrieve data from the dataset given a list of
|
||||||
|
indices.
|
||||||
:param dict data: Dictionary containing the data
|
|
||||||
:param list idx_list: List of indices to retrieve
|
|
||||||
:return: Dictionary containing the data at the given indices
|
|
||||||
:rtype: dict
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class PinaTensorDataset(PinaDataset):
|
class PinaTensorDataset(PinaDataset):
|
||||||
"""
|
"""
|
||||||
Class for the PINA dataset with torch.Tensor data
|
Dataset class for the PINA dataset with :class:`torch.Tensor` and
|
||||||
|
:class:`LabelTensor` data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Override _retrive_data method for torch.Tensor data
|
# Override _retrive_data method for torch.Tensor data
|
||||||
def _retrive_data(self, data, idx_list):
|
def _retrive_data(self, data, idx_list):
|
||||||
"""
|
"""
|
||||||
Retrieve data from the dataset given a list of indices
|
Retrieve data from the dataset given a list of indices.
|
||||||
|
|
||||||
:param data: Dictionary containing the data
|
:param data: Dictionary containing the data
|
||||||
(only torch.Tensor/LableTensor)
|
(only torch.Tensor/LableTensor).
|
||||||
:type data: dict
|
:type data: dict
|
||||||
:param list(int) idx_list: indices to retrieve
|
:param list(int) idx_list: indices to retrieve.
|
||||||
:return: Dictionary containing the data at the given indices
|
:return: Dictionary containing the data at the given indices.
|
||||||
:rtype: dict
|
:rtype: dict
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -222,9 +226,9 @@ class PinaTensorDataset(PinaDataset):
|
|||||||
@property
|
@property
|
||||||
def input(self):
|
def input(self):
|
||||||
"""
|
"""
|
||||||
Method to return all input points from the dataset.
|
Return the input data for the dataset.
|
||||||
|
|
||||||
:return: Dictionary containing the input points
|
:return: Dictionary containing the input points.
|
||||||
:rtype: dict
|
:rtype: dict
|
||||||
"""
|
"""
|
||||||
return {k: v["input"] for k, v in self.conditions_dict.items()}
|
return {k: v["input"] for k, v in self.conditions_dict.items()}
|
||||||
@@ -232,15 +236,17 @@ class PinaTensorDataset(PinaDataset):
|
|||||||
|
|
||||||
class PinaGraphDataset(PinaDataset):
|
class PinaGraphDataset(PinaDataset):
|
||||||
"""
|
"""
|
||||||
Class for the PINA dataset with torch_geometric.data.Data data
|
Dataset class for the PINA dataset with :class:`torch_geometric.data.Data`
|
||||||
|
and :class:`Graph` data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _create_graph_batch(self, data):
|
def _create_graph_batch(self, data):
|
||||||
"""
|
"""
|
||||||
Create a LabelBatch object from a list of Data objects.
|
Create a LabelBatch object from a list of
|
||||||
|
:class:`torch_geometric.data.Data` objects.
|
||||||
|
|
||||||
:param data: List of Data or Graph objects
|
:param data: List of items to collate in a single batch.
|
||||||
:type data: list(Data) | list(Graph)
|
:type data: list(torch_geometric.data.Data) | list(Graph)
|
||||||
:return: LabelBatch object all the graph collated in a single batch
|
:return: LabelBatch object all the graph collated in a single batch
|
||||||
disconnected graphs.
|
disconnected graphs.
|
||||||
:rtype: LabelBatch
|
:rtype: LabelBatch
|
||||||
@@ -255,7 +261,7 @@ class PinaGraphDataset(PinaDataset):
|
|||||||
:param data: torch.Tensor object of shape (N, ...) where N is the
|
:param data: torch.Tensor object of shape (N, ...) where N is the
|
||||||
number of data points.
|
number of data points.
|
||||||
:type data: torch.Tensor | LabelTensor
|
:type data: torch.Tensor | LabelTensor
|
||||||
:return: reshaped torch.Tensor or LabelTensor object
|
:return: reshaped torch.Tensor or LabelTensor object.
|
||||||
:rtype: torch.Tensor | LabelTensor
|
:rtype: torch.Tensor | LabelTensor
|
||||||
"""
|
"""
|
||||||
out = data.reshape(-1, *data.shape[2:])
|
out = data.reshape(-1, *data.shape[2:])
|
||||||
@@ -263,12 +269,13 @@ class PinaGraphDataset(PinaDataset):
|
|||||||
|
|
||||||
def create_batch(self, data):
|
def create_batch(self, data):
|
||||||
"""
|
"""
|
||||||
Create a Batch object from a list of Data objects.
|
Create a Batch object from a list of :class:`torch_geometric.data.Data`
|
||||||
|
objects.
|
||||||
|
|
||||||
:param data: List of Data objects
|
:param data: List of items to collate in a single batch.
|
||||||
:type data: list
|
:type data: list
|
||||||
:return: Batch object
|
:return: Batch object.
|
||||||
:rtype: Batch or PinaBatch
|
:rtype: Batch | PinaBatch
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(data[0], Data):
|
if isinstance(data[0], Data):
|
||||||
@@ -278,13 +285,14 @@ class PinaGraphDataset(PinaDataset):
|
|||||||
# Override _retrive_data method for graph handling
|
# Override _retrive_data method for graph handling
|
||||||
def _retrive_data(self, data, idx_list):
|
def _retrive_data(self, data, idx_list):
|
||||||
"""
|
"""
|
||||||
Retrieve data from the dataset given a list of indices
|
Retrieve data from the dataset given a list of indices.
|
||||||
|
|
||||||
:param dict data: dictionary containing the data
|
:param dict data: Dictionary containing the data.
|
||||||
:param list idx_list: list of indices to retrieve
|
:param list idx_list: List of indices to retrieve.
|
||||||
:return: dictionary containing the data at the given indices
|
:return: Dictionary containing the data at the given indices.
|
||||||
:rtype: dict
|
:rtype: dict
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Return the data from the current condition
|
# Return the data from the current condition
|
||||||
# If the data is a list of Data objects, create a Batch object
|
# If the data is a list of Data objects, create a Batch object
|
||||||
# If the data is a list of torch.Tensor objects, create a torch.Tensor
|
# If the data is a list of torch.Tensor objects, create a torch.Tensor
|
||||||
|
|||||||
Reference in New Issue
Block a user