Doc data
This commit is contained in:
committed by
Nicola Demo
parent
99d2e70f4a
commit
cbbaa4062f
@@ -16,16 +16,16 @@ from ..collector import Collector
|
||||
|
||||
|
||||
class DummyDataloader:
|
||||
""" "
|
||||
Dummy dataloader used when batch size is None. It callects all the data
|
||||
in self.dataset and returns it when it is called a single batch.
|
||||
"""
|
||||
Dataloader used when batch size is ``None``. It returns the entire dataset
|
||||
in a single batch.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset):
|
||||
"""
|
||||
Prepare a dataloader object which will return the entire dataset
|
||||
in a single batch. Depending on the number of GPUs, the dataset we
|
||||
have the following cases:
|
||||
in a single batch. Depending on the number of GPUs, the dataset is
|
||||
managed as follows:
|
||||
|
||||
- **Distributed Environment** (multiple GPUs):
|
||||
- Divides the dataset across processes using the rank and world
|
||||
@@ -38,7 +38,7 @@ class DummyDataloader:
|
||||
:param dataset: The dataset object to be processed.
|
||||
:type dataset: PinaDataset
|
||||
|
||||
.. note:: This data loader is used when the batch size is None.
|
||||
.. note:: This data loader is used when the batch size is ``None``.
|
||||
"""
|
||||
|
||||
if (
|
||||
@@ -72,7 +72,9 @@ class DummyDataloader:
|
||||
|
||||
class Collator:
|
||||
"""
|
||||
Class used to collate the batch
|
||||
This callable class is used to collate the data points fetched from the
|
||||
dataset. The collation is performed based on the type of dataset used and
|
||||
on the batching strategy.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -121,7 +123,13 @@ class Collator:
|
||||
def _collate_torch_dataloader(self, batch):
|
||||
"""
|
||||
Function used to collate the batch
|
||||
|
||||
:param list(dict) batch: List of retrieved data.
|
||||
:return: Dictionary containing the data points fetched from the dataset,
|
||||
collated.
|
||||
:rtype: dict
|
||||
"""
|
||||
|
||||
batch_dict = {}
|
||||
if isinstance(batch, dict):
|
||||
return batch
|
||||
@@ -149,15 +157,15 @@ class Collator:
|
||||
def _collate_tensor_dataset(data_list):
|
||||
"""
|
||||
Function used to collate the data when the dataset is a
|
||||
`PinaTensorDataset`.
|
||||
:class:`PinaTensorDataset`.
|
||||
|
||||
:param data_list: List of `torch.Tensor` or `LabelTensor` to be
|
||||
collated.
|
||||
:param data_list: Elements to be collated.
|
||||
:type data_list: list(torch.Tensor) | list(LabelTensor)
|
||||
:raises RuntimeError: If the data is not a `torch.Tensor` or a
|
||||
`LabelTensor`.
|
||||
:return: Batch of data
|
||||
:return: Batch of data.
|
||||
:rtype: dict
|
||||
|
||||
:raises RuntimeError: If the data is not a :class:`torch.Tensor` or a
|
||||
:class:`LabelTensor`.
|
||||
"""
|
||||
|
||||
if isinstance(data_list[0], LabelTensor):
|
||||
@@ -169,13 +177,15 @@ class Collator:
|
||||
def _collate_graph_dataset(self, data_list):
|
||||
"""
|
||||
Function used to collate the data when the dataset is a
|
||||
`PinaGraphDataset`.
|
||||
:class:`PinaGraphDataset`.
|
||||
|
||||
:param data_list: List of `Data` or `Graph` to be collated.
|
||||
:type data_list: list(Data) | list(Graph)
|
||||
:raises RuntimeError: If the data is not a `Data` or a `Graph`.
|
||||
:return: Batch of data
|
||||
:param data_list: Elements to be collated.
|
||||
:type data_list: list(torch_geometric.data.Data) | list(Graph)
|
||||
:return: Batch of data.
|
||||
:rtype: dict
|
||||
|
||||
:raises RuntimeError: If the data is not a
|
||||
:class:`torch_geometric.data.Data` or a :class:`Graph`.
|
||||
"""
|
||||
|
||||
if isinstance(data_list[0], LabelTensor):
|
||||
@@ -184,13 +194,18 @@ class Collator:
|
||||
return torch.cat(data_list)
|
||||
if isinstance(data_list[0], Data):
|
||||
return self.dataset.create_batch(data_list)
|
||||
raise RuntimeError("Data must be Tensors or LabelTensor or pyG Data")
|
||||
raise RuntimeError(
|
||||
"Data must be Tensors or LabelTensor or pyG "
|
||||
"torch_geometric.data.Data"
|
||||
)
|
||||
|
||||
def __call__(self, batch):
|
||||
"""
|
||||
Call the function to collate the batch, defined in __init__.
|
||||
Perform the collation of the data points fetched from the dataset.
|
||||
The behavior of the function is set based on the batching strategy
|
||||
during class initialization.
|
||||
|
||||
:param batch: list of indices or list of retrieved data
|
||||
:param batch: List of retrieved data or sampled indices.
|
||||
:type batch: list(int) | list(dict)
|
||||
:return: Dictionary containing the data points fetched from the dataset,
|
||||
collated.
|
||||
@@ -202,13 +217,14 @@ class Collator:
|
||||
|
||||
class PinaSampler:
|
||||
"""
|
||||
Class used to create the sampler instance.
|
||||
This class is used to create the sampler instance based on the shuffle
|
||||
parameter and the environment in which the code is running.
|
||||
"""
|
||||
|
||||
def __new__(cls, dataset, shuffle):
|
||||
"""
|
||||
Create the sampler instance, according to shuffle and whether the
|
||||
environment is distributed or not.
|
||||
Instantiate the sampler based on the environment in which the code is
|
||||
running.
|
||||
|
||||
:param PinaDataset dataset: The dataset to be sampled.
|
||||
:param bool shuffle: whether to shuffle the dataset or not before
|
||||
@@ -232,8 +248,9 @@ class PinaSampler:
|
||||
|
||||
class PinaDataModule(LightningDataModule):
|
||||
"""
|
||||
This class extend LightningDataModule, allowing proper creation and
|
||||
management of different types of Datasets defined in PINA
|
||||
This class extends :class:`pytorch_lightning.LightningDataModule`,
|
||||
allowing proper creation and management of different types of datasets
|
||||
defined in PINA.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -253,23 +270,31 @@ class PinaDataModule(LightningDataModule):
|
||||
Initialize the object, creating datasets based on the input problem.
|
||||
|
||||
:param AbstractProblem problem: The problem containing the data on which
|
||||
to train/test the model.
|
||||
to create the datasets and dataloaders.
|
||||
:param float train_size: Fraction or number of elements in the training
|
||||
split.
|
||||
split. It must be in the range [0, 1].
|
||||
:param float test_size: Fraction or number of elements in the test
|
||||
split.
|
||||
split. It must be in the range [0, 1].
|
||||
:param float val_size: Fraction or number of elements in the validation
|
||||
split.
|
||||
split. It must be in the range [0, 1].
|
||||
:param batch_size: The batch size used for training. If `None`, the
|
||||
entire dataset is used per batch.
|
||||
:type batch_size: int | None
|
||||
:param bool shuffle: Whether to shuffle the dataset before splitting.
|
||||
Default True.
|
||||
:param bool repeat: Whether to repeat the dataset indefinitely.
|
||||
Default False.
|
||||
:param automatic_batching: Whether to enable automatic batching.
|
||||
Default False.
|
||||
:param int num_workers: Number of worker threads for data loading.
|
||||
Default 0 (serial loading)
|
||||
Default 0 (serial loading). For more information, see
|
||||
https://pytorch.org/docs/stable/data.html#multi-process-data-loading
|
||||
:param bool pin_memory: Whether to use pinned memory for faster data
|
||||
transfer to GPU. (Default False)
|
||||
transfer to GPU. (Default False). For more information, see
|
||||
https://pytorch.org/docs/stable/data.html#memory-pinning
|
||||
|
||||
:raises ValueError: If at least one of the splits is negative.
|
||||
:raises ValueError: If the sum of the splits is different from 1.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@@ -278,6 +303,8 @@ class PinaDataModule(LightningDataModule):
|
||||
self.shuffle = shuffle
|
||||
self.repeat = repeat
|
||||
self.automatic_batching = automatic_batching
|
||||
|
||||
# If batch size is None, num_workers has no effect
|
||||
if batch_size is None and num_workers != 0:
|
||||
warnings.warn(
|
||||
"Setting num_workers when batch_size is None has no effect on "
|
||||
@@ -286,6 +313,8 @@ class PinaDataModule(LightningDataModule):
|
||||
self.num_workers = 0
|
||||
else:
|
||||
self.num_workers = num_workers
|
||||
|
||||
# If batch size is None, pin_memory has no effect
|
||||
if batch_size is None and pin_memory:
|
||||
warnings.warn(
|
||||
"Setting pin_memory to True has no effect when "
|
||||
@@ -309,16 +338,22 @@ class PinaDataModule(LightningDataModule):
|
||||
splits_dict["train"] = train_size
|
||||
self.train_dataset = None
|
||||
else:
|
||||
# Use the super method to create the train dataloader which
|
||||
# raises NotImplementedError
|
||||
self.train_dataloader = super().train_dataloader
|
||||
if test_size > 0:
|
||||
splits_dict["test"] = test_size
|
||||
self.test_dataset = None
|
||||
else:
|
||||
# Use the super method to create the test dataloader which
|
||||
# raises NotImplementedError
|
||||
self.test_dataloader = super().test_dataloader
|
||||
if val_size > 0:
|
||||
splits_dict["val"] = val_size
|
||||
self.val_dataset = None
|
||||
else:
|
||||
# Use the super method to create the validation dataloader which
|
||||
# raises NotImplementedError
|
||||
self.val_dataloader = super().val_dataloader
|
||||
|
||||
self.collector_splits = self._create_splits(collector, splits_dict)
|
||||
@@ -326,7 +361,13 @@ class PinaDataModule(LightningDataModule):
|
||||
|
||||
def setup(self, stage=None):
|
||||
"""
|
||||
Perform the splitting of the dataset
|
||||
Create the dataset objects for the given stage.
|
||||
If the stage is "fit", the training and validation datasets are created.
|
||||
If the stage is "test", the testing dataset is created.
|
||||
|
||||
:param str stage: The stage for which to perform the splitting.
|
||||
|
||||
:raises ValueError: If the stage is neither "fit" nor "test".
|
||||
"""
|
||||
if stage == "fit" or stage is None:
|
||||
self.train_dataset = PinaDatasetFactory(
|
||||
@@ -354,8 +395,18 @@ class PinaDataModule(LightningDataModule):
|
||||
raise ValueError("stage must be either 'fit' or 'test'.")
|
||||
|
||||
@staticmethod
|
||||
def _split_condition(condition_dict, splits_dict):
|
||||
len_condition = len(condition_dict["input"])
|
||||
def _split_condition(single_condition_dict, splits_dict):
|
||||
"""
|
||||
Split the condition into different stages.
|
||||
|
||||
:param dict single_condition_dict: The condition to be split.
|
||||
:param dict splits_dict: The dictionary containing the number of
|
||||
elements in each stage.
|
||||
:return: A dictionary containing the split condition.
|
||||
:rtype: dict
|
||||
"""
|
||||
|
||||
len_condition = len(single_condition_dict["input"])
|
||||
|
||||
lengths = [
|
||||
int(len_condition * length) for length in splits_dict.values()
|
||||
@@ -374,7 +425,7 @@ class PinaDataModule(LightningDataModule):
|
||||
for stage, stage_len in splits_dict.items():
|
||||
to_return_dict[stage] = {
|
||||
k: v[offset : offset + stage_len]
|
||||
for k, v in condition_dict.items()
|
||||
for k, v in single_condition_dict.items()
|
||||
if k != "equation"
|
||||
# Equations are NEVER dataloaded
|
||||
}
|
||||
@@ -386,7 +437,13 @@ class PinaDataModule(LightningDataModule):
|
||||
|
||||
def _create_splits(self, collector, splits_dict):
|
||||
"""
|
||||
Create the dataset objects putting data
|
||||
Create the dataset objects putting data in the correct splits.
|
||||
|
||||
:param Collector collector: The collector object containing the data.
|
||||
:param dict splits_dict: The dictionary containing the number of
|
||||
elements in each stage.
|
||||
:return: The dictionary containing the dataset objects.
|
||||
:rtype: dict
|
||||
"""
|
||||
|
||||
# ----------- Auxiliary function ------------
|
||||
@@ -422,6 +479,15 @@ class PinaDataModule(LightningDataModule):
|
||||
return dataset_dict
|
||||
|
||||
def _create_dataloader(self, split, dataset):
|
||||
""" "
|
||||
Create the dataloader for the given split.
|
||||
|
||||
:param str split: The split on which to create the dataloader.
|
||||
:param PinaDataset dataset: The dataset to be used for the dataloader.
|
||||
:return: The dataloader for the given split.
|
||||
:rtype: torch.utils.data.DataLoader
|
||||
"""
|
||||
|
||||
shuffle = self.shuffle if split == "train" else False
|
||||
# Suppress the warning about num_workers.
|
||||
# In many cases, especially for PINNs,
|
||||
@@ -470,6 +536,7 @@ class PinaDataModule(LightningDataModule):
|
||||
:return: The maximum length of the conditions.
|
||||
:rtype: dict
|
||||
"""
|
||||
|
||||
max_conditions_lengths = {}
|
||||
for k, v in self.collector_splits[split].items():
|
||||
if self.batch_size is None:
|
||||
@@ -484,7 +551,10 @@ class PinaDataModule(LightningDataModule):
|
||||
|
||||
def val_dataloader(self):
|
||||
"""
|
||||
Create the validation dataloader
|
||||
Create the validation dataloader.
|
||||
|
||||
:return: The validation dataloader
|
||||
:rtype: DataLoader
|
||||
"""
|
||||
return self._create_dataloader("val", self.val_dataset)
|
||||
|
||||
@@ -509,20 +579,17 @@ class PinaDataModule(LightningDataModule):
|
||||
@staticmethod
|
||||
def _transfer_batch_to_device_dummy(batch, device, dataloader_idx):
|
||||
"""
|
||||
Transfer the batch to the device. This method is called in the
|
||||
training loop and is used to transfer the batch to the device.
|
||||
This method is used when the batch size is None: batch has already
|
||||
been transferred to the device.
|
||||
Transfer the batch to the device. This method is used when the batch
|
||||
size is None: batch has already been transferred to the device.
|
||||
|
||||
:param list(tuple) batch: list of tuple where the first element of the
|
||||
tuple is the condition name and the second element is the data.
|
||||
:param device: device to which the batch is transferred
|
||||
:type device: torch.device
|
||||
:param dataloader_idx: index of the dataloader
|
||||
:type dataloader_idx: int
|
||||
:param torch.device device: device to which the batch is transferred.
|
||||
:param int dataloader_idx: index of the dataloader.
|
||||
:return: The batch transferred to the device.
|
||||
:rtype: list(tuple)
|
||||
"""
|
||||
|
||||
return batch
|
||||
|
||||
def _transfer_batch_to_device(self, batch, device, dataloader_idx):
|
||||
@@ -531,12 +598,13 @@ class PinaDataModule(LightningDataModule):
|
||||
training loop and is used to transfer the batch to the device.
|
||||
|
||||
:param dict batch: The batch to be transferred to the device.
|
||||
:param device: The device to which the batch is transferred.
|
||||
:type device: torch.device
|
||||
:param torch.device device: The device to which the batch is
|
||||
transferred.
|
||||
:param int dataloader_idx: The index of the dataloader.
|
||||
:return: The batch transferred to the device.
|
||||
:rtype: list(tuple)
|
||||
"""
|
||||
|
||||
batch = [
|
||||
(
|
||||
k,
|
||||
@@ -552,8 +620,18 @@ class PinaDataModule(LightningDataModule):
|
||||
@staticmethod
|
||||
def _check_slit_sizes(train_size, test_size, val_size):
|
||||
"""
|
||||
Check if the splits are correct
|
||||
Check if the splits are correct. The split sizes must be non-negative and
|
||||
the sum of the splits must be 1.
|
||||
|
||||
:param float train_size: The size of the training split.
|
||||
:param float test_size: The size of the testing split.
|
||||
:param float val_size: The size of the validation split.
|
||||
|
||||
:raises ValueError: If at least one of the splits is negative.
|
||||
:raises ValueError: If the sum of the splits is different
|
||||
from 1.
|
||||
"""
|
||||
|
||||
if train_size < 0 or test_size < 0 or val_size < 0:
|
||||
raise ValueError("The splits must be positive")
|
||||
if abs(train_size + test_size + val_size - 1) > 1e-6:
|
||||
@@ -567,6 +645,7 @@ class PinaDataModule(LightningDataModule):
|
||||
:return: The input points for training.
|
||||
:rtype: dict
|
||||
"""
|
||||
|
||||
to_return = {}
|
||||
if hasattr(self, "train_dataset") and self.train_dataset is not None:
|
||||
to_return["train"] = self.train_dataset.input
|
||||
|
||||
Reference in New Issue
Block a user