Other fixes

This commit is contained in:
FilippoOlivo
2025-03-12 12:29:18 +01:00
parent d857b47002
commit 033d36c5a8
9 changed files with 85 additions and 82 deletions

View File

@@ -15,26 +15,26 @@ class PinaDatasetFactory:
Depending on the type inside the conditions, it creates a different dataset
object:
- :class:`pina.data.dataset.PinaTensorDataset` for handling
:class:`torch.Tensor` and :class:`pina.label_tensor.LabelTensor` data.
- :class:`pina.data.dataset.PinaGraphDataset` for handling
:class:`pina.graph.Graph` and :class:`~torch_geometric.data.Data` data.
- :class:`~pina.data.dataset.PinaTensorDataset` for handling
:class:`torch.Tensor` and :class:`~pina.label_tensor.LabelTensor` data.
- :class:`~pina.data.dataset.PinaGraphDataset` for handling
:class:`~pina.graph.Graph` and :class:`~torch_geometric.data.Data` data.
"""
def __new__(cls, conditions_dict, **kwargs):
"""
Instantiate the appropriate subclass of
:class:`pina.data.dataset.PinaDataset`.
:class:`~pina.data.dataset.PinaDataset`.
If a graph is present in the conditions, returns a
:class:`pina.data.dataset.PinaGraphDataset`, otherwise returns a
:class:`pina.data.dataset.PinaTensorDataset`.
:class:`~pina.data.dataset.PinaGraphDataset`, otherwise returns a
:class:`~pina.data.dataset.PinaTensorDataset`.
:param dict conditions_dict: Dictionary containing all the conditions
to be included in the dataset instance.
:return: A subclass of :class:`pina.data.dataset.PinaDataset`.
:rtype: :class:`pina.data.dataset.PinaTensorDataset` |
:class:`pina.data.dataset.PinaGraphDataset`
:return: A subclass of :class:`~pina.data.dataset.PinaDataset`.
:rtype: :class:`~pina.data.dataset.PinaTensorDataset` |
:class:`~pina.data.dataset.PinaGraphDataset`
:raises ValueError: If an empty dictionary is provided.
"""
@@ -75,15 +75,15 @@ class PinaDatasetFactory:
class PinaDataset(Dataset, ABC):
"""
Abstract class for the PINA dataset. It defines the common interface for
the :class:`pina.data.dataset.PinaTensorDataset` and
:class:`pina.data.dataset.PinaGraphDataset` classes.
the :class:`~pina.data.dataset.PinaTensorDataset` and
:class:`~pina.data.dataset.PinaGraphDataset` classes.
"""
def __init__(
self, conditions_dict, max_conditions_lengths, automatic_batching
):
"""
Initialize a :class:`pina.data.dataset.PinaDataset` instance by storing
Initialize a :class:`~pina.data.dataset.PinaDataset` instance by storing
the providedconditions dictionary, the maximum number of conditions to
consider, and the automatic batching flag.
@@ -93,7 +93,7 @@ class PinaDataset(Dataset, ABC):
points to include in a single batch for each condition.
:param bool automatic_batching: Indicates whether PyTorch automatic
batching is enabled in
:class:`pina.data.data_module.PinaDataModule`.
:class:`~pina.data.data_module.PinaDataModule`.
"""
# Store the conditions dictionary
@@ -205,7 +205,7 @@ class PinaDataset(Dataset, ABC):
class PinaTensorDataset(PinaDataset):
"""
Dataset class for the PINA dataset with :class:`torch.Tensor` and
:class:`pina.label_tensor.LabelTensor` data.
:class:`~pina.label_tensor.LabelTensor` data.
"""
# Override _retrive_data method for torch.Tensor data
@@ -215,7 +215,7 @@ class PinaTensorDataset(PinaDataset):
:param dict data: Dictionary containing the data
(only :class:`torch.Tensor` or
:class:`pina.label_tensor.LabelTensor`).
:class:`~pina.label_tensor.LabelTensor`).
:param list[int] idx_list: indices to retrieve.
:return: Dictionary containing the data at the given indices.
:rtype: dict
@@ -237,7 +237,7 @@ class PinaTensorDataset(PinaDataset):
class PinaGraphDataset(PinaDataset):
"""
Dataset class for the PINA dataset with :class:`~torch_geometric.data.Data`
and :class:`pina.graph.Graph` data.
and :class:`~pina.graph.Graph` data.
"""
def _create_graph_batch(self, data):