Fix doc data
This commit is contained in:
committed by
Nicola Demo
parent
6f97799284
commit
59e6ee595c
@@ -17,7 +17,7 @@ from ..collector import Collector
|
|||||||
|
|
||||||
class DummyDataloader:
|
class DummyDataloader:
|
||||||
"""
|
"""
|
||||||
Dataloader used when batch size is ``None``. It returns the entire dataset
|
Dataloader used when batch size is `None`. It returns the entire dataset
|
||||||
in a single batch.
|
in a single batch.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -38,7 +38,7 @@ class DummyDataloader:
|
|||||||
:param dataset: The dataset object to be processed.
|
:param dataset: The dataset object to be processed.
|
||||||
:type dataset: PinaDataset
|
:type dataset: PinaDataset
|
||||||
|
|
||||||
.. note:: This data loader is used when the batch size is ``None``.
|
.. note:: This data loader is used when the batch size is `None`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@@ -157,7 +157,7 @@ class Collator:
|
|||||||
def _collate_tensor_dataset(data_list):
|
def _collate_tensor_dataset(data_list):
|
||||||
"""
|
"""
|
||||||
Function used to collate the data when the dataset is a
|
Function used to collate the data when the dataset is a
|
||||||
:class:`PinaTensorDataset`.
|
:class:`pina.data.dataset.PinaTensorDataset`.
|
||||||
|
|
||||||
:param data_list: Elements to be collated.
|
:param data_list: Elements to be collated.
|
||||||
:type data_list: list[torch.Tensor] | list[LabelTensor]
|
:type data_list: list[torch.Tensor] | list[LabelTensor]
|
||||||
@@ -165,7 +165,7 @@ class Collator:
|
|||||||
:rtype: dict
|
:rtype: dict
|
||||||
|
|
||||||
:raises RuntimeError: If the data is not a :class:`torch.Tensor` or a
|
:raises RuntimeError: If the data is not a :class:`torch.Tensor` or a
|
||||||
:class:`LabelTensor`.
|
:class:`pina.label_tensor.LabelTensor`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(data_list[0], LabelTensor):
|
if isinstance(data_list[0], LabelTensor):
|
||||||
@@ -177,15 +177,15 @@ class Collator:
|
|||||||
def _collate_graph_dataset(self, data_list):
|
def _collate_graph_dataset(self, data_list):
|
||||||
"""
|
"""
|
||||||
Function used to collate the data when the dataset is a
|
Function used to collate the data when the dataset is a
|
||||||
:class:`PinaGraphDataset`.
|
:class:`pina.data.dataset.PinaGraphDataset`.
|
||||||
|
|
||||||
:param data_list: Elememts to be collated.
|
:param data_list: Elememts to be collated.
|
||||||
:type data_list: list[torch_geometric.data.Data] | list[Graph]
|
:type data_list: list[Data] | list[Graph]
|
||||||
:return: Batch of data.
|
:return: Batch of data.
|
||||||
:rtype: dict
|
:rtype: dict
|
||||||
|
|
||||||
:raises RuntimeError: If the data is not a
|
:raises RuntimeError: If the data is not a
|
||||||
:class:`torch_geometric.data.Data` or a :class:`Graph`.
|
:class:`~torch_geometric.data.Data` or a :class:`pina.graph.Graph`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(data_list[0], LabelTensor):
|
if isinstance(data_list[0], LabelTensor):
|
||||||
@@ -248,7 +248,7 @@ class PinaSampler:
|
|||||||
|
|
||||||
class PinaDataModule(LightningDataModule):
|
class PinaDataModule(LightningDataModule):
|
||||||
"""
|
"""
|
||||||
This class extends :class:`pytorch_lightning.LightningDataModule`,
|
This class extends :class:`lightning.pytorch.LightningDataModule`,
|
||||||
allowing proper creation and management of different types of datasets
|
allowing proper creation and management of different types of datasets
|
||||||
defined in PINA.
|
defined in PINA.
|
||||||
"""
|
"""
|
||||||
@@ -536,8 +536,7 @@ class PinaDataModule(LightningDataModule):
|
|||||||
"""
|
"""
|
||||||
Define the maximum length of the conditions.
|
Define the maximum length of the conditions.
|
||||||
|
|
||||||
:param split: The splits of the dataset.
|
:param dict split: The splits of the dataset.
|
||||||
:type split: dict
|
|
||||||
:return: The maximum length of the conditions.
|
:return: The maximum length of the conditions.
|
||||||
:rtype: dict
|
:rtype: dict
|
||||||
"""
|
"""
|
||||||
@@ -559,7 +558,7 @@ class PinaDataModule(LightningDataModule):
|
|||||||
Create the validation dataloader.
|
Create the validation dataloader.
|
||||||
|
|
||||||
:return: The validation dataloader
|
:return: The validation dataloader
|
||||||
:rtype: DataLoader
|
:rtype: torch.utils.data.DataLoader
|
||||||
"""
|
"""
|
||||||
return self._create_dataloader("val", self.val_dataset)
|
return self._create_dataloader("val", self.val_dataset)
|
||||||
|
|
||||||
@@ -568,7 +567,7 @@ class PinaDataModule(LightningDataModule):
|
|||||||
Create the training dataloader
|
Create the training dataloader
|
||||||
|
|
||||||
:return: The training dataloader
|
:return: The training dataloader
|
||||||
:rtype: DataLoader
|
:rtype: torch.utils.data.DataLoader
|
||||||
"""
|
"""
|
||||||
return self._create_dataloader("train", self.train_dataset)
|
return self._create_dataloader("train", self.train_dataset)
|
||||||
|
|
||||||
@@ -577,7 +576,7 @@ class PinaDataModule(LightningDataModule):
|
|||||||
Create the testing dataloader
|
Create the testing dataloader
|
||||||
|
|
||||||
:return: The testing dataloader
|
:return: The testing dataloader
|
||||||
:rtype: DataLoader
|
:rtype: torch.utils.data.DataLoader
|
||||||
"""
|
"""
|
||||||
return self._create_dataloader("test", self.test_dataset)
|
return self._create_dataloader("test", self.test_dataset)
|
||||||
|
|
||||||
|
|||||||
@@ -15,24 +15,26 @@ class PinaDatasetFactory:
|
|||||||
Depending on the type inside the conditions, it creates a different dataset
|
Depending on the type inside the conditions, it creates a different dataset
|
||||||
object:
|
object:
|
||||||
|
|
||||||
- :class:`PinaTensorDataset` for handling :class:`torch.Tensor` and
|
- :class:`pina.data.dataset.PinaTensorDataset` for handling
|
||||||
:class:`LabelTensor` data.
|
:class:`torch.Tensor` and :class:`pina.label_tensor.LabelTensor` data.
|
||||||
- :class:`PinaGraphDataset` for handling :class:`Graph` and :class:`Data`
|
- :class:`pina.data.dataset.PinaGraphDataset` for handling
|
||||||
data.
|
:class:`pina.graph.Graph` and :class:`~torch_geometric.data.Data` data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(cls, conditions_dict, **kwargs):
|
def __new__(cls, conditions_dict, **kwargs):
|
||||||
"""
|
"""
|
||||||
Instantiate the appropriate subclass of :class:`PinaDataset`.
|
Instantiate the appropriate subclass of
|
||||||
|
:class:`pina.data.dataset.PinaDataset`.
|
||||||
|
|
||||||
If a graph is present in the conditions, returns a
|
If a graph is present in the conditions, returns a
|
||||||
:class:`PinaGraphDataset`, otherwise returns a
|
:class:`pina.data.dataset.PinaGraphDataset`, otherwise returns a
|
||||||
:class:`PinaTensorDataset`.
|
:class:`pina.data.dataset.PinaTensorDataset`.
|
||||||
|
|
||||||
:param dict conditions_dict: Dictionary containing all the conditions
|
:param dict conditions_dict: Dictionary containing all the conditions
|
||||||
to be included in the dataset instance.
|
to be included in the dataset instance.
|
||||||
:return: A subclass of :class:`PinaDataset`.
|
:return: A subclass of :class:`pina.data.dataset.PinaDataset`.
|
||||||
:rtype: :class:`PinaTensorDataset` | :class:`PinaGraphDataset`
|
:rtype: :class:`pina.data.dataset.PinaTensorDataset` |
|
||||||
|
:class:`pina.data.dataset.PinaGraphDataset`
|
||||||
|
|
||||||
:raises ValueError: If an empty dictionary is provided.
|
:raises ValueError: If an empty dictionary is provided.
|
||||||
"""
|
"""
|
||||||
@@ -73,25 +75,25 @@ class PinaDatasetFactory:
|
|||||||
class PinaDataset(Dataset, ABC):
|
class PinaDataset(Dataset, ABC):
|
||||||
"""
|
"""
|
||||||
Abstract class for the PINA dataset. It defines the common interface for
|
Abstract class for the PINA dataset. It defines the common interface for
|
||||||
the :class:`PinaTensorDataset` and :class:`PinaGraphDataset` classes.
|
the :class:`pina.data.dataset.PinaTensorDataset` and
|
||||||
|
:class:`pina.data.dataset.PinaGraphDataset` classes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, conditions_dict, max_conditions_lengths, automatic_batching
|
self, conditions_dict, max_conditions_lengths, automatic_batching
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize a :class:`PinaDataset` instance by storing the provided
|
Initialize a :class:`pina.data.dataset.PinaDataset` instance by storing
|
||||||
conditions dictionary, the maximum number of conditions to consider,
|
the providedconditions dictionary, the maximum number of conditions to
|
||||||
and the automatic batching flag.
|
consider, and the automatic batching flag.
|
||||||
|
|
||||||
:param conditions_dict: Dictionary containing the conditions.
|
:param dict conditions_dict: Dictionary containing the conditions with
|
||||||
:type conditions_dict: dict
|
data.
|
||||||
:param max_conditions_lengths: Specifies the maximum number of data
|
:param dict max_conditions_lengths: Specifies the maximum number of data
|
||||||
points to include in a single batch for each condition.
|
points to include in a single batch for each condition.
|
||||||
:type max_conditions_lengths: dict
|
:param bool automatic_batching: Indicates whether PyTorch automatic
|
||||||
:param automatic_batching: Indicates whether PyTorch automatic batching
|
batching is enabled in
|
||||||
is enabled in :class:`PinaDataModule`.
|
:class:`pina.data.data_module.PinaDataModule`.
|
||||||
:type automatic_batching: bool
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Store the conditions dictionary
|
# Store the conditions dictionary
|
||||||
@@ -134,8 +136,7 @@ class PinaDataset(Dataset, ABC):
|
|||||||
Return the index itself. This is used when automatic batching is
|
Return the index itself. This is used when automatic batching is
|
||||||
disabled to postpone the data retrieval to the dataloader.
|
disabled to postpone the data retrieval to the dataloader.
|
||||||
|
|
||||||
:param idx: Index.
|
:param int idx: Index.
|
||||||
:type idx: int
|
|
||||||
:return: Index.
|
:return: Index.
|
||||||
:rtype: int
|
:rtype: int
|
||||||
"""
|
"""
|
||||||
@@ -174,8 +175,7 @@ class PinaDataset(Dataset, ABC):
|
|||||||
"""
|
"""
|
||||||
Return data from the dataset given a list of indices.
|
Return data from the dataset given a list of indices.
|
||||||
|
|
||||||
:param idx: List of indices.
|
:param list[int] idx: List of indices.
|
||||||
:type idx: list[int]
|
|
||||||
:return: A dictionary containing the data at the given indices.
|
:return: A dictionary containing the data at the given indices.
|
||||||
:rtype: dict
|
:rtype: dict
|
||||||
"""
|
"""
|
||||||
@@ -205,7 +205,7 @@ class PinaDataset(Dataset, ABC):
|
|||||||
class PinaTensorDataset(PinaDataset):
|
class PinaTensorDataset(PinaDataset):
|
||||||
"""
|
"""
|
||||||
Dataset class for the PINA dataset with :class:`torch.Tensor` and
|
Dataset class for the PINA dataset with :class:`torch.Tensor` and
|
||||||
:class:`LabelTensor` data.
|
:class:`pina.label_tensor.LabelTensor` data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Override _retrive_data method for torch.Tensor data
|
# Override _retrive_data method for torch.Tensor data
|
||||||
@@ -213,9 +213,9 @@ class PinaTensorDataset(PinaDataset):
|
|||||||
"""
|
"""
|
||||||
Retrieve data from the dataset given a list of indices.
|
Retrieve data from the dataset given a list of indices.
|
||||||
|
|
||||||
:param data: Dictionary containing the data
|
:param dict data: Dictionary containing the data
|
||||||
(only torch.Tensor/LableTensor).
|
(only :class:`torch.Tensor` or
|
||||||
:type data: dict
|
:class:`pina.label_tensor.LabelTensor`).
|
||||||
:param list[int] idx_list: indices to retrieve.
|
:param list[int] idx_list: indices to retrieve.
|
||||||
:return: Dictionary containing the data at the given indices.
|
:return: Dictionary containing the data at the given indices.
|
||||||
:rtype: dict
|
:rtype: dict
|
||||||
@@ -236,17 +236,17 @@ class PinaTensorDataset(PinaDataset):
|
|||||||
|
|
||||||
class PinaGraphDataset(PinaDataset):
|
class PinaGraphDataset(PinaDataset):
|
||||||
"""
|
"""
|
||||||
Dataset class for the PINA dataset with :class:`torch_geometric.data.Data`
|
Dataset class for the PINA dataset with :class:`~torch_geometric.data.Data`
|
||||||
and :class:`Graph` data.
|
and :class:`pina.graph.Graph` data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _create_graph_batch(self, data):
|
def _create_graph_batch(self, data):
|
||||||
"""
|
"""
|
||||||
Create a LabelBatch object from a list of
|
Create a LabelBatch object from a list of
|
||||||
:class:`torch_geometric.data.Data` objects.
|
:class:`~torch_geometric.data.Data` objects.
|
||||||
|
|
||||||
:param data: List of items to collate in a single batch.
|
:param data: List of items to collate in a single batch.
|
||||||
:type data: list[torch_geometric.data.Data] | list[Graph]
|
:type data: list[Data] | list[Graph]
|
||||||
:return: LabelBatch object all the graph collated in a single batch
|
:return: LabelBatch object all the graph collated in a single batch
|
||||||
disconnected graphs.
|
disconnected graphs.
|
||||||
:rtype: LabelBatch
|
:rtype: LabelBatch
|
||||||
@@ -256,13 +256,13 @@ class PinaGraphDataset(PinaDataset):
|
|||||||
|
|
||||||
def _create_tensor_batch(self, data):
|
def _create_tensor_batch(self, data):
|
||||||
"""
|
"""
|
||||||
Reshape properly ``data`` tensor to be processed handle by the graph
|
Reshape properly `data` tensor to be processed handle by the graph
|
||||||
based models.
|
based models.
|
||||||
|
|
||||||
:param data: torch.Tensor object of shape (N, ...) where N is the
|
:param data: torch.Tensor object of shape (N, ...) where N is the
|
||||||
number of data points.
|
number of data points.
|
||||||
:type data: torch.Tensor | LabelTensor
|
:type data: torch.Tensor | LabelTensor
|
||||||
:return: reshaped torch.Tensor or LabelTensor object.
|
:return: Reshaped tensor object.
|
||||||
:rtype: torch.Tensor | LabelTensor
|
:rtype: torch.Tensor | LabelTensor
|
||||||
"""
|
"""
|
||||||
out = data.reshape(-1, *data.shape[2:])
|
out = data.reshape(-1, *data.shape[2:])
|
||||||
@@ -270,11 +270,11 @@ class PinaGraphDataset(PinaDataset):
|
|||||||
|
|
||||||
def create_batch(self, data):
|
def create_batch(self, data):
|
||||||
"""
|
"""
|
||||||
Create a Batch object from a list of :class:`torch_geometric.data.Data`
|
Create a Batch object from a list of :class:`~torch_geometric.data.Data`
|
||||||
objects.
|
objects.
|
||||||
|
|
||||||
:param data: List of items to collate in a single batch.
|
:param data: List of items to collate in a single batch.
|
||||||
:type data: list[torch_geometric.data.Data] | list[Graph]
|
:type data: list[Data] | list[Graph]
|
||||||
:return: Batch object.
|
:return: Batch object.
|
||||||
:rtype: Batch | PinaBatch
|
:rtype: Batch | PinaBatch
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user