Fix doc data
This commit is contained in:
committed by
Nicola Demo
parent
d411543b76
commit
7dd954ce50
@@ -176,7 +176,7 @@ class Collator:
|
|||||||
|
|
||||||
def _collate_graph_dataset(self, data_list):
|
def _collate_graph_dataset(self, data_list):
|
||||||
"""
|
"""
|
||||||
Function used to collate the data when the dataset is a
|
Function used to collate data when the dataset is a
|
||||||
:class:`~pina.data.dataset.PinaGraphDataset`.
|
:class:`~pina.data.dataset.PinaGraphDataset`.
|
||||||
|
|
||||||
:param data_list: Elememts to be collated.
|
:param data_list: Elememts to be collated.
|
||||||
@@ -187,7 +187,6 @@ class Collator:
|
|||||||
:raises RuntimeError: If the data is not a
|
:raises RuntimeError: If the data is not a
|
||||||
:class:`~torch_geometric.data.Data` or a :class:`~pina.graph.Graph`.
|
:class:`~torch_geometric.data.Data` or a :class:`~pina.graph.Graph`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(data_list[0], LabelTensor):
|
if isinstance(data_list[0], LabelTensor):
|
||||||
return LabelTensor.cat(data_list)
|
return LabelTensor.cat(data_list)
|
||||||
if isinstance(data_list[0], torch.Tensor):
|
if isinstance(data_list[0], torch.Tensor):
|
||||||
@@ -201,14 +200,13 @@ class Collator:
|
|||||||
|
|
||||||
def __call__(self, batch):
|
def __call__(self, batch):
|
||||||
"""
|
"""
|
||||||
Perform the collation of the data points fetched from the dataset.
|
Perform the collation of data fetched from the dataset. The behavoior
|
||||||
The behavoior of the function is set based on the batching strategy
|
of the function is set based on the batching strategy during class
|
||||||
during class initialization.
|
initialization.
|
||||||
|
|
||||||
:param batch: List of retrieved data or sampled indices.
|
:param batch: List of retrieved data or sampled indices.
|
||||||
:type batch: list[int] | list[dict]
|
:type batch: list[int] | list[dict]
|
||||||
:return: Dictionary containing the data points fetched from the dataset,
|
:return: Dictionary containing colleted data fetched from the dataset.
|
||||||
collated.
|
|
||||||
:rtype: dict
|
:rtype: dict
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -223,12 +221,10 @@ class PinaSampler:
|
|||||||
|
|
||||||
def __new__(cls, dataset, shuffle):
|
def __new__(cls, dataset, shuffle):
|
||||||
"""
|
"""
|
||||||
Instantiate the sampler based on the environment in which the code is
|
Instantiate and initialize the sampler.
|
||||||
running.
|
|
||||||
|
|
||||||
:param PinaDataset dataset: The dataset to be sampled.
|
:param PinaDataset dataset: The dataset from which to sample.
|
||||||
:param bool shuffle: whether to shuffle the dataset or not before
|
:param bool shuffle: Whether to shuffle the dataset.
|
||||||
sampling.
|
|
||||||
:return: The sampler instance.
|
:return: The sampler instance.
|
||||||
:rtype: torch.utils.data.Sampler
|
:rtype: torch.utils.data.Sampler
|
||||||
"""
|
"""
|
||||||
@@ -267,18 +263,18 @@ class PinaDataModule(LightningDataModule):
|
|||||||
pin_memory=False,
|
pin_memory=False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the object, creating datasets based on the input problem.
|
Initialize the object and creating datasets based on the input problem.
|
||||||
|
|
||||||
:param AbstractProblem problem: The problem containing the data on which
|
:param AbstractProblem problem: The problem containing the data on which
|
||||||
to create the datasets and dataloaders.
|
to create the datasets and dataloaders.
|
||||||
:param float train_size: Fraction or number of elements in the training
|
:param float train_size: Fraction of elements in the training split. It
|
||||||
split. It must be in the range [0, 1].
|
must be in the range [0, 1].
|
||||||
:param float test_size: Fraction or number of elements in the test
|
:param float test_size: Fraction of elements in the test split. It must
|
||||||
split. It must be in the range [0, 1].
|
be in the range [0, 1].
|
||||||
:param float val_size: Fraction or number of elements in the validation
|
:param float val_size: Fraction of elements in the validation split. It
|
||||||
split. It must be in the range [0, 1].
|
must be in the range [0, 1].
|
||||||
:param batch_size: The batch size used for training. If `None`, the
|
:param batch_size: The batch size used for training. If `None`, the
|
||||||
entire dataset is used per batch.
|
entire dataset is returned in a single batch.
|
||||||
:type batch_size: int | None
|
:type batch_size: int | None
|
||||||
:param bool shuffle: Whether to shuffle the dataset before splitting.
|
:param bool shuffle: Whether to shuffle the dataset before splitting.
|
||||||
Default True.
|
Default True.
|
||||||
@@ -289,7 +285,7 @@ class PinaDataModule(LightningDataModule):
|
|||||||
:param int num_workers: Number of worker threads for data loading.
|
:param int num_workers: Number of worker threads for data loading.
|
||||||
Default 0 (serial loading).
|
Default 0 (serial loading).
|
||||||
:param bool pin_memory: Whether to use pinned memory for faster data
|
:param bool pin_memory: Whether to use pinned memory for faster data
|
||||||
transfer to GPU. (Default False).
|
transfer to GPU. Default False.
|
||||||
|
|
||||||
:raises ValueError: If at least one of the splits is negative.
|
:raises ValueError: If at least one of the splits is negative.
|
||||||
:raises ValueError: If the sum of the splits is different from 1.
|
:raises ValueError: If the sum of the splits is different from 1.
|
||||||
@@ -370,7 +366,7 @@ class PinaDataModule(LightningDataModule):
|
|||||||
If the stage is "fit", the training and validation datasets are created.
|
If the stage is "fit", the training and validation datasets are created.
|
||||||
If the stage is "test", the testing dataset is created.
|
If the stage is "test", the testing dataset is created.
|
||||||
|
|
||||||
:param str stage: The stage for which to perform the splitting.
|
:param str stage: The stage for which to perform the dataset setup.
|
||||||
|
|
||||||
:raises ValueError: If the stage is neither "fit" nor "test".
|
:raises ValueError: If the stage is neither "fit" nor "test".
|
||||||
"""
|
"""
|
||||||
@@ -534,10 +530,10 @@ class PinaDataModule(LightningDataModule):
|
|||||||
|
|
||||||
def find_max_conditions_lengths(self, split):
|
def find_max_conditions_lengths(self, split):
|
||||||
"""
|
"""
|
||||||
Define the maximum length of the conditions.
|
Define the maximum length for each conditions.
|
||||||
|
|
||||||
:param dict split: The splits of the dataset.
|
:param dict split: The split of the dataset.
|
||||||
:return: The maximum length of the conditions.
|
:return: The maximum length per condition.
|
||||||
:rtype: dict
|
:rtype: dict
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ class PinaDatasetFactory:
|
|||||||
class PinaDataset(Dataset, ABC):
|
class PinaDataset(Dataset, ABC):
|
||||||
"""
|
"""
|
||||||
Abstract class for the PINA dataset. It defines the common interface for
|
Abstract class for the PINA dataset. It defines the common interface for
|
||||||
the :class:`~pina.data.dataset.PinaTensorDataset` and
|
:class:`~pina.data.dataset.PinaTensorDataset` and
|
||||||
:class:`~pina.data.dataset.PinaGraphDataset` classes.
|
:class:`~pina.data.dataset.PinaGraphDataset` classes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -83,9 +83,8 @@ class PinaDataset(Dataset, ABC):
|
|||||||
self, conditions_dict, max_conditions_lengths, automatic_batching
|
self, conditions_dict, max_conditions_lengths, automatic_batching
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize a :class:`~pina.data.dataset.PinaDataset` instance by storing
|
Initialize :class:`~pina.data.dataset.PinaDataset` instance by storing
|
||||||
the providedconditions dictionary, the maximum number of conditions to
|
the provided conditions dictionary, and the automatic batching flag.
|
||||||
consider, and the automatic batching flag.
|
|
||||||
|
|
||||||
:param dict conditions_dict: Dictionary containing the conditions with
|
:param dict conditions_dict: Dictionary containing the conditions with
|
||||||
data.
|
data.
|
||||||
|
|||||||
Reference in New Issue
Block a user