Fix doc data
This commit is contained in:
@@ -17,7 +17,7 @@ from ..collector import Collector
|
||||
|
||||
class DummyDataloader:
|
||||
"""
|
||||
Dataloader used when batch size is ``None``. It returns the entire dataset
|
||||
Dataloader used when batch size is `None`. It returns the entire dataset
|
||||
in a single batch.
|
||||
"""
|
||||
|
||||
@@ -38,7 +38,7 @@ class DummyDataloader:
|
||||
:param dataset: The dataset object to be processed.
|
||||
:type dataset: PinaDataset
|
||||
|
||||
.. note:: This data loader is used when the batch size is ``None``.
|
||||
.. note:: This data loader is used when the batch size is `None`.
|
||||
"""
|
||||
|
||||
if (
|
||||
@@ -157,7 +157,7 @@ class Collator:
|
||||
def _collate_tensor_dataset(data_list):
|
||||
"""
|
||||
Function used to collate the data when the dataset is a
|
||||
:class:`PinaTensorDataset`.
|
||||
:class:`pina.data.dataset.PinaTensorDataset`.
|
||||
|
||||
:param data_list: Elements to be collated.
|
||||
:type data_list: list[torch.Tensor] | list[LabelTensor]
|
||||
@@ -165,7 +165,7 @@ class Collator:
|
||||
:rtype: dict
|
||||
|
||||
:raises RuntimeError: If the data is not a :class:`torch.Tensor` or a
|
||||
:class:`LabelTensor`.
|
||||
:class:`pina.label_tensor.LabelTensor`.
|
||||
"""
|
||||
|
||||
if isinstance(data_list[0], LabelTensor):
|
||||
@@ -177,15 +177,15 @@ class Collator:
|
||||
def _collate_graph_dataset(self, data_list):
|
||||
"""
|
||||
Function used to collate the data when the dataset is a
|
||||
:class:`PinaGraphDataset`.
|
||||
:class:`pina.data.dataset.PinaGraphDataset`.
|
||||
|
||||
:param data_list: Elememts to be collated.
|
||||
:type data_list: list[torch_geometric.data.Data] | list[Graph]
|
||||
:type data_list: list[Data] | list[Graph]
|
||||
:return: Batch of data.
|
||||
:rtype: dict
|
||||
|
||||
:raises RuntimeError: If the data is not a
|
||||
:class:`torch_geometric.data.Data` or a :class:`Graph`.
|
||||
:class:`~torch_geometric.data.Data` or a :class:`pina.graph.Graph`.
|
||||
"""
|
||||
|
||||
if isinstance(data_list[0], LabelTensor):
|
||||
@@ -248,7 +248,7 @@ class PinaSampler:
|
||||
|
||||
class PinaDataModule(LightningDataModule):
|
||||
"""
|
||||
This class extends :class:`pytorch_lightning.LightningDataModule`,
|
||||
This class extends :class:`lightning.pytorch.LightningDataModule`,
|
||||
allowing proper creation and management of different types of datasets
|
||||
defined in PINA.
|
||||
"""
|
||||
@@ -536,8 +536,7 @@ class PinaDataModule(LightningDataModule):
|
||||
"""
|
||||
Define the maximum length of the conditions.
|
||||
|
||||
:param split: The splits of the dataset.
|
||||
:type split: dict
|
||||
:param dict split: The splits of the dataset.
|
||||
:return: The maximum length of the conditions.
|
||||
:rtype: dict
|
||||
"""
|
||||
@@ -559,7 +558,7 @@ class PinaDataModule(LightningDataModule):
|
||||
Create the validation dataloader.
|
||||
|
||||
:return: The validation dataloader
|
||||
:rtype: DataLoader
|
||||
:rtype: torch.utils.data.DataLoader
|
||||
"""
|
||||
return self._create_dataloader("val", self.val_dataset)
|
||||
|
||||
@@ -568,7 +567,7 @@ class PinaDataModule(LightningDataModule):
|
||||
Create the training dataloader
|
||||
|
||||
:return: The training dataloader
|
||||
:rtype: DataLoader
|
||||
:rtype: torch.utils.data.DataLoader
|
||||
"""
|
||||
return self._create_dataloader("train", self.train_dataset)
|
||||
|
||||
@@ -577,7 +576,7 @@ class PinaDataModule(LightningDataModule):
|
||||
Create the testing dataloader
|
||||
|
||||
:return: The testing dataloader
|
||||
:rtype: DataLoader
|
||||
:rtype: torch.utils.data.DataLoader
|
||||
"""
|
||||
return self._create_dataloader("test", self.test_dataset)
|
||||
|
||||
|
||||
@@ -15,24 +15,26 @@ class PinaDatasetFactory:
|
||||
Depending on the type inside the conditions, it creates a different dataset
|
||||
object:
|
||||
|
||||
- :class:`PinaTensorDataset` for handling :class:`torch.Tensor` and
|
||||
:class:`LabelTensor` data.
|
||||
- :class:`PinaGraphDataset` for handling :class:`Graph` and :class:`Data`
|
||||
data.
|
||||
- :class:`pina.data.dataset.PinaTensorDataset` for handling
|
||||
:class:`torch.Tensor` and :class:`pina.label_tensor.LabelTensor` data.
|
||||
- :class:`pina.data.dataset.PinaGraphDataset` for handling
|
||||
:class:`pina.graph.Graph` and :class:`~torch_geometric.data.Data` data.
|
||||
"""
|
||||
|
||||
def __new__(cls, conditions_dict, **kwargs):
|
||||
"""
|
||||
Instantiate the appropriate subclass of :class:`PinaDataset`.
|
||||
Instantiate the appropriate subclass of
|
||||
:class:`pina.data.dataset.PinaDataset`.
|
||||
|
||||
If a graph is present in the conditions, returns a
|
||||
:class:`PinaGraphDataset`, otherwise returns a
|
||||
:class:`PinaTensorDataset`.
|
||||
:class:`pina.data.dataset.PinaGraphDataset`, otherwise returns a
|
||||
:class:`pina.data.dataset.PinaTensorDataset`.
|
||||
|
||||
:param dict conditions_dict: Dictionary containing all the conditions
|
||||
to be included in the dataset instance.
|
||||
:return: A subclass of :class:`PinaDataset`.
|
||||
:rtype: :class:`PinaTensorDataset` | :class:`PinaGraphDataset`
|
||||
:return: A subclass of :class:`pina.data.dataset.PinaDataset`.
|
||||
:rtype: :class:`pina.data.dataset.PinaTensorDataset` |
|
||||
:class:`pina.data.dataset.PinaGraphDataset`
|
||||
|
||||
:raises ValueError: If an empty dictionary is provided.
|
||||
"""
|
||||
@@ -73,25 +75,25 @@ class PinaDatasetFactory:
|
||||
class PinaDataset(Dataset, ABC):
|
||||
"""
|
||||
Abstract class for the PINA dataset. It defines the common interface for
|
||||
the :class:`PinaTensorDataset` and :class:`PinaGraphDataset` classes.
|
||||
the :class:`pina.data.dataset.PinaTensorDataset` and
|
||||
:class:`pina.data.dataset.PinaGraphDataset` classes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, conditions_dict, max_conditions_lengths, automatic_batching
|
||||
):
|
||||
"""
|
||||
Initialize a :class:`PinaDataset` instance by storing the provided
|
||||
conditions dictionary, the maximum number of conditions to consider,
|
||||
and the automatic batching flag.
|
||||
Initialize a :class:`pina.data.dataset.PinaDataset` instance by storing
|
||||
the providedconditions dictionary, the maximum number of conditions to
|
||||
consider, and the automatic batching flag.
|
||||
|
||||
:param conditions_dict: Dictionary containing the conditions.
|
||||
:type conditions_dict: dict
|
||||
:param max_conditions_lengths: Specifies the maximum number of data
|
||||
:param dict conditions_dict: Dictionary containing the conditions with
|
||||
data.
|
||||
:param dict max_conditions_lengths: Specifies the maximum number of data
|
||||
points to include in a single batch for each condition.
|
||||
:type max_conditions_lengths: dict
|
||||
:param automatic_batching: Indicates whether PyTorch automatic batching
|
||||
is enabled in :class:`PinaDataModule`.
|
||||
:type automatic_batching: bool
|
||||
:param bool automatic_batching: Indicates whether PyTorch automatic
|
||||
batching is enabled in
|
||||
:class:`pina.data.data_module.PinaDataModule`.
|
||||
"""
|
||||
|
||||
# Store the conditions dictionary
|
||||
@@ -134,8 +136,7 @@ class PinaDataset(Dataset, ABC):
|
||||
Return the index itself. This is used when automatic batching is
|
||||
disabled to postpone the data retrieval to the dataloader.
|
||||
|
||||
:param idx: Index.
|
||||
:type idx: int
|
||||
:param int idx: Index.
|
||||
:return: Index.
|
||||
:rtype: int
|
||||
"""
|
||||
@@ -174,8 +175,7 @@ class PinaDataset(Dataset, ABC):
|
||||
"""
|
||||
Return data from the dataset given a list of indices.
|
||||
|
||||
:param idx: List of indices.
|
||||
:type idx: list[int]
|
||||
:param list[int] idx: List of indices.
|
||||
:return: A dictionary containing the data at the given indices.
|
||||
:rtype: dict
|
||||
"""
|
||||
@@ -205,7 +205,7 @@ class PinaDataset(Dataset, ABC):
|
||||
class PinaTensorDataset(PinaDataset):
|
||||
"""
|
||||
Dataset class for the PINA dataset with :class:`torch.Tensor` and
|
||||
:class:`LabelTensor` data.
|
||||
:class:`pina.label_tensor.LabelTensor` data.
|
||||
"""
|
||||
|
||||
# Override _retrive_data method for torch.Tensor data
|
||||
@@ -213,9 +213,9 @@ class PinaTensorDataset(PinaDataset):
|
||||
"""
|
||||
Retrieve data from the dataset given a list of indices.
|
||||
|
||||
:param data: Dictionary containing the data
|
||||
(only torch.Tensor/LableTensor).
|
||||
:type data: dict
|
||||
:param dict data: Dictionary containing the data
|
||||
(only :class:`torch.Tensor` or
|
||||
:class:`pina.label_tensor.LabelTensor`).
|
||||
:param list[int] idx_list: indices to retrieve.
|
||||
:return: Dictionary containing the data at the given indices.
|
||||
:rtype: dict
|
||||
@@ -236,17 +236,17 @@ class PinaTensorDataset(PinaDataset):
|
||||
|
||||
class PinaGraphDataset(PinaDataset):
|
||||
"""
|
||||
Dataset class for the PINA dataset with :class:`torch_geometric.data.Data`
|
||||
and :class:`Graph` data.
|
||||
Dataset class for the PINA dataset with :class:`~torch_geometric.data.Data`
|
||||
and :class:`pina.graph.Graph` data.
|
||||
"""
|
||||
|
||||
def _create_graph_batch(self, data):
|
||||
"""
|
||||
Create a LabelBatch object from a list of
|
||||
:class:`torch_geometric.data.Data` objects.
|
||||
:class:`~torch_geometric.data.Data` objects.
|
||||
|
||||
:param data: List of items to collate in a single batch.
|
||||
:type data: list[torch_geometric.data.Data] | list[Graph]
|
||||
:type data: list[Data] | list[Graph]
|
||||
:return: LabelBatch object all the graph collated in a single batch
|
||||
disconnected graphs.
|
||||
:rtype: LabelBatch
|
||||
@@ -256,13 +256,13 @@ class PinaGraphDataset(PinaDataset):
|
||||
|
||||
def _create_tensor_batch(self, data):
|
||||
"""
|
||||
Reshape properly ``data`` tensor to be processed handle by the graph
|
||||
Reshape properly `data` tensor to be processed handle by the graph
|
||||
based models.
|
||||
|
||||
:param data: torch.Tensor object of shape (N, ...) where N is the
|
||||
number of data points.
|
||||
:type data: torch.Tensor | LabelTensor
|
||||
:return: reshaped torch.Tensor or LabelTensor object.
|
||||
:return: Reshaped tensor object.
|
||||
:rtype: torch.Tensor | LabelTensor
|
||||
"""
|
||||
out = data.reshape(-1, *data.shape[2:])
|
||||
@@ -270,11 +270,11 @@ class PinaGraphDataset(PinaDataset):
|
||||
|
||||
def create_batch(self, data):
|
||||
"""
|
||||
Create a Batch object from a list of :class:`torch_geometric.data.Data`
|
||||
Create a Batch object from a list of :class:`~torch_geometric.data.Data`
|
||||
objects.
|
||||
|
||||
:param data: List of items to collate in a single batch.
|
||||
:type data: list[torch_geometric.data.Data] | list[Graph]
|
||||
:type data: list[Data] | list[Graph]
|
||||
:return: Batch object.
|
||||
:rtype: Batch | PinaBatch
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user