Other fixes
This commit is contained in:
committed by
Nicola Demo
parent
ae796ce34c
commit
f587c3bf65
@@ -15,26 +15,26 @@ class PinaDatasetFactory:
|
||||
Depending on the type inside the conditions, it creates a different dataset
|
||||
object:
|
||||
|
||||
- :class:`pina.data.dataset.PinaTensorDataset` for handling
|
||||
:class:`torch.Tensor` and :class:`pina.label_tensor.LabelTensor` data.
|
||||
- :class:`pina.data.dataset.PinaGraphDataset` for handling
|
||||
:class:`pina.graph.Graph` and :class:`~torch_geometric.data.Data` data.
|
||||
- :class:`~pina.data.dataset.PinaTensorDataset` for handling
|
||||
:class:`torch.Tensor` and :class:`~pina.label_tensor.LabelTensor` data.
|
||||
- :class:`~pina.data.dataset.PinaGraphDataset` for handling
|
||||
:class:`~pina.graph.Graph` and :class:`~torch_geometric.data.Data` data.
|
||||
"""
|
||||
|
||||
def __new__(cls, conditions_dict, **kwargs):
|
||||
"""
|
||||
Instantiate the appropriate subclass of
|
||||
:class:`pina.data.dataset.PinaDataset`.
|
||||
:class:`~pina.data.dataset.PinaDataset`.
|
||||
|
||||
If a graph is present in the conditions, returns a
|
||||
:class:`pina.data.dataset.PinaGraphDataset`, otherwise returns a
|
||||
:class:`pina.data.dataset.PinaTensorDataset`.
|
||||
:class:`~pina.data.dataset.PinaGraphDataset`, otherwise returns a
|
||||
:class:`~pina.data.dataset.PinaTensorDataset`.
|
||||
|
||||
:param dict conditions_dict: Dictionary containing all the conditions
|
||||
to be included in the dataset instance.
|
||||
:return: A subclass of :class:`pina.data.dataset.PinaDataset`.
|
||||
:rtype: :class:`pina.data.dataset.PinaTensorDataset` |
|
||||
:class:`pina.data.dataset.PinaGraphDataset`
|
||||
:return: A subclass of :class:`~pina.data.dataset.PinaDataset`.
|
||||
:rtype: :class:`~pina.data.dataset.PinaTensorDataset` |
|
||||
:class:`~pina.data.dataset.PinaGraphDataset`
|
||||
|
||||
:raises ValueError: If an empty dictionary is provided.
|
||||
"""
|
||||
@@ -75,15 +75,15 @@ class PinaDatasetFactory:
|
||||
class PinaDataset(Dataset, ABC):
|
||||
"""
|
||||
Abstract class for the PINA dataset. It defines the common interface for
|
||||
the :class:`pina.data.dataset.PinaTensorDataset` and
|
||||
:class:`pina.data.dataset.PinaGraphDataset` classes.
|
||||
the :class:`~pina.data.dataset.PinaTensorDataset` and
|
||||
:class:`~pina.data.dataset.PinaGraphDataset` classes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, conditions_dict, max_conditions_lengths, automatic_batching
|
||||
):
|
||||
"""
|
||||
Initialize a :class:`pina.data.dataset.PinaDataset` instance by storing
|
||||
Initialize a :class:`~pina.data.dataset.PinaDataset` instance by storing
|
||||
the providedconditions dictionary, the maximum number of conditions to
|
||||
consider, and the automatic batching flag.
|
||||
|
||||
@@ -93,7 +93,7 @@ class PinaDataset(Dataset, ABC):
|
||||
points to include in a single batch for each condition.
|
||||
:param bool automatic_batching: Indicates whether PyTorch automatic
|
||||
batching is enabled in
|
||||
:class:`pina.data.data_module.PinaDataModule`.
|
||||
:class:`~pina.data.data_module.PinaDataModule`.
|
||||
"""
|
||||
|
||||
# Store the conditions dictionary
|
||||
@@ -205,7 +205,7 @@ class PinaDataset(Dataset, ABC):
|
||||
class PinaTensorDataset(PinaDataset):
|
||||
"""
|
||||
Dataset class for the PINA dataset with :class:`torch.Tensor` and
|
||||
:class:`pina.label_tensor.LabelTensor` data.
|
||||
:class:`~pina.label_tensor.LabelTensor` data.
|
||||
"""
|
||||
|
||||
# Override _retrive_data method for torch.Tensor data
|
||||
@@ -215,7 +215,7 @@ class PinaTensorDataset(PinaDataset):
|
||||
|
||||
:param dict data: Dictionary containing the data
|
||||
(only :class:`torch.Tensor` or
|
||||
:class:`pina.label_tensor.LabelTensor`).
|
||||
:class:`~pina.label_tensor.LabelTensor`).
|
||||
:param list[int] idx_list: indices to retrieve.
|
||||
:return: Dictionary containing the data at the given indices.
|
||||
:rtype: dict
|
||||
@@ -237,7 +237,7 @@ class PinaTensorDataset(PinaDataset):
|
||||
class PinaGraphDataset(PinaDataset):
|
||||
"""
|
||||
Dataset class for the PINA dataset with :class:`~torch_geometric.data.Data`
|
||||
and :class:`pina.graph.Graph` data.
|
||||
and :class:`~pina.graph.Graph` data.
|
||||
"""
|
||||
|
||||
def _create_graph_batch(self, data):
|
||||
|
||||
Reference in New Issue
Block a user