Fix doc data
This commit is contained in:
committed by
Nicola Demo
parent
d411543b76
commit
7dd954ce50
@@ -176,7 +176,7 @@ class Collator:
|
||||
|
||||
def _collate_graph_dataset(self, data_list):
|
||||
"""
|
||||
Function used to collate the data when the dataset is a
|
||||
Function used to collate data when the dataset is a
|
||||
:class:`~pina.data.dataset.PinaGraphDataset`.
|
||||
|
||||
:param data_list: Elememts to be collated.
|
||||
@@ -187,7 +187,6 @@ class Collator:
|
||||
:raises RuntimeError: If the data is not a
|
||||
:class:`~torch_geometric.data.Data` or a :class:`~pina.graph.Graph`.
|
||||
"""
|
||||
|
||||
if isinstance(data_list[0], LabelTensor):
|
||||
return LabelTensor.cat(data_list)
|
||||
if isinstance(data_list[0], torch.Tensor):
|
||||
@@ -201,14 +200,13 @@ class Collator:
|
||||
|
||||
def __call__(self, batch):
|
||||
"""
|
||||
Perform the collation of the data points fetched from the dataset.
|
||||
The behavoior of the function is set based on the batching strategy
|
||||
during class initialization.
|
||||
Perform the collation of data fetched from the dataset. The behavoior
|
||||
of the function is set based on the batching strategy during class
|
||||
initialization.
|
||||
|
||||
:param batch: List of retrieved data or sampled indices.
|
||||
:type batch: list[int] | list[dict]
|
||||
:return: Dictionary containing the data points fetched from the dataset,
|
||||
collated.
|
||||
:return: Dictionary containing colleted data fetched from the dataset.
|
||||
:rtype: dict
|
||||
"""
|
||||
|
||||
@@ -223,12 +221,10 @@ class PinaSampler:
|
||||
|
||||
def __new__(cls, dataset, shuffle):
|
||||
"""
|
||||
Instantiate the sampler based on the environment in which the code is
|
||||
running.
|
||||
Instantiate and initialize the sampler.
|
||||
|
||||
:param PinaDataset dataset: The dataset to be sampled.
|
||||
:param bool shuffle: whether to shuffle the dataset or not before
|
||||
sampling.
|
||||
:param PinaDataset dataset: The dataset from which to sample.
|
||||
:param bool shuffle: Whether to shuffle the dataset.
|
||||
:return: The sampler instance.
|
||||
:rtype: torch.utils.data.Sampler
|
||||
"""
|
||||
@@ -267,18 +263,18 @@ class PinaDataModule(LightningDataModule):
|
||||
pin_memory=False,
|
||||
):
|
||||
"""
|
||||
Initialize the object, creating datasets based on the input problem.
|
||||
Initialize the object and creating datasets based on the input problem.
|
||||
|
||||
:param AbstractProblem problem: The problem containing the data on which
|
||||
to create the datasets and dataloaders.
|
||||
:param float train_size: Fraction or number of elements in the training
|
||||
split. It must be in the range [0, 1].
|
||||
:param float test_size: Fraction or number of elements in the test
|
||||
split. It must be in the range [0, 1].
|
||||
:param float val_size: Fraction or number of elements in the validation
|
||||
split. It must be in the range [0, 1].
|
||||
:param float train_size: Fraction of elements in the training split. It
|
||||
must be in the range [0, 1].
|
||||
:param float test_size: Fraction of elements in the test split. It must
|
||||
be in the range [0, 1].
|
||||
:param float val_size: Fraction of elements in the validation split. It
|
||||
must be in the range [0, 1].
|
||||
:param batch_size: The batch size used for training. If `None`, the
|
||||
entire dataset is used per batch.
|
||||
entire dataset is returned in a single batch.
|
||||
:type batch_size: int | None
|
||||
:param bool shuffle: Whether to shuffle the dataset before splitting.
|
||||
Default True.
|
||||
@@ -289,7 +285,7 @@ class PinaDataModule(LightningDataModule):
|
||||
:param int num_workers: Number of worker threads for data loading.
|
||||
Default 0 (serial loading).
|
||||
:param bool pin_memory: Whether to use pinned memory for faster data
|
||||
transfer to GPU. (Default False).
|
||||
transfer to GPU. Default False.
|
||||
|
||||
:raises ValueError: If at least one of the splits is negative.
|
||||
:raises ValueError: If the sum of the splits is different from 1.
|
||||
@@ -370,7 +366,7 @@ class PinaDataModule(LightningDataModule):
|
||||
If the stage is "fit", the training and validation datasets are created.
|
||||
If the stage is "test", the testing dataset is created.
|
||||
|
||||
:param str stage: The stage for which to perform the splitting.
|
||||
:param str stage: The stage for which to perform the dataset setup.
|
||||
|
||||
:raises ValueError: If the stage is neither "fit" nor "test".
|
||||
"""
|
||||
@@ -534,10 +530,10 @@ class PinaDataModule(LightningDataModule):
|
||||
|
||||
def find_max_conditions_lengths(self, split):
|
||||
"""
|
||||
Define the maximum length of the conditions.
|
||||
Define the maximum length for each conditions.
|
||||
|
||||
:param dict split: The splits of the dataset.
|
||||
:return: The maximum length of the conditions.
|
||||
:param dict split: The split of the dataset.
|
||||
:return: The maximum length per condition.
|
||||
:rtype: dict
|
||||
"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user