This commit is contained in:
FilippoOlivo
2025-03-11 17:18:48 +01:00
parent 3fd12669bb
commit 72ce6edaa7
3 changed files with 193 additions and 106 deletions

View File

@@ -1,8 +1,8 @@
"""
This module provide basic data management functionalities
Module for the PINA dataset
"""
from abc import abstractmethod
from abc import abstractmethod, ABC
from torch.utils.data import Dataset
from torch_geometric.data import Data
from ..graph import Graph, LabelBatch
@@ -15,9 +15,10 @@ class PinaDatasetFactory:
Depending on the type inside the conditions, it creates a different dataset
object:
- :class:`PinaTensorDataset` for `torch.Tensor`
- :class:`PinaGraphDataset` for `list` of `torch_geometric.data.Data`
objects
- :class:`PinaTensorDataset` for handling :class:`torch.Tensor` and
:class:`LabelTensor` data.
- :class:`PinaGraphDataset` for handling :class:`Graph` and :class:`Data`
data.
"""
def __new__(cls, conditions_dict, **kwargs):
@@ -28,7 +29,8 @@ class PinaDatasetFactory:
:class:`PinaGraphDataset`, otherwise returns a
:class:`PinaTensorDataset`.
:param dict conditions_dict: Dictionary containing the conditions.
:param dict conditions_dict: Dictionary containing all the conditions
to be included in the dataset instance.
:return: A subclass of :class:`PinaDataset`.
:rtype: :class:`PinaTensorDataset` | :class:`PinaGraphDataset`
@@ -50,11 +52,11 @@ class PinaDatasetFactory:
@staticmethod
def _is_graph_dataset(conditions_dict):
"""
Check if a graph is present in the conditions.
Check if at least one graph is present in the conditions.
:param conditions_dict: Dictionary containing the conditions.
:type conditions_dict: dict
:return: True if a graph is present in the conditions, False otherwise
:return: True if a graph is present in the conditions, False otherwise.
:rtype: bool
"""
@@ -68,25 +70,28 @@ class PinaDatasetFactory:
return False
class PinaDataset(Dataset):
class PinaDataset(Dataset, ABC):
"""
Abstract class for the PINA dataset
Abstract class for the PINA dataset. It defines the common interface for
the :class:`PinaTensorDataset` and :class:`PinaGraphDataset` classes.
"""
def __init__(
self, conditions_dict, max_conditions_lengths, automatic_batching
):
"""
Initialize the :class:`PinaDataset`.
Initialize a :class:`PinaDataset` instance by storing the provided
conditions dictionary, the maximum number of conditions to consider,
and the automatic batching flag.
Stores the conditions dictionary, the maximum number of conditions to
consider, and the automatic batching flag.
:param dict conditions_dict: Dictionary containing the conditions.
:param dict max_conditions_lengths: Maximum number of data points to
consider in a single batch for each condition.
:param bool automatic_batching: Whether PyTorch automatic batching is
enabled in :class:`PinaDataModule`.
:param conditions_dict: Dictionary containing the conditions.
:type conditions_dict: dict
:param max_conditions_lengths: Specifies the maximum number of data
points to include in a single batch for each condition.
:type max_conditions_lengths: dict
:param automatic_batching: Indicates whether PyTorch automatic batching
is enabled in :class:`PinaDataModule`.
:type automatic_batching: bool
"""
# Store the conditions dictionary
@@ -107,9 +112,9 @@ class PinaDataset(Dataset):
def _get_max_len(self):
"""
Returns the length of the longest condition in the dataset
Returns the length of the longest condition in the dataset.
:return: Length of the longest condition in the dataset
:return: Length of the longest condition in the dataset.
:rtype: int
"""
@@ -129,9 +134,9 @@ class PinaDataset(Dataset):
Return the index itself. This is used when automatic batching is
disabled to postpone the data retrieval to the dataloader.
:param idx: Index
:param idx: Index.
:type idx: int
:return: Index
:return: Index.
:rtype: int
"""
@@ -143,8 +148,8 @@ class PinaDataset(Dataset):
Return the data at the given index in the dataset. This is used when
automatic batching is enabled.
:param int idx: Index
:return: A dictionary containing the data at the given index
:param int idx: Index.
:return: A dictionary containing the data at the given index.
:rtype: dict
"""
@@ -156,23 +161,25 @@ class PinaDataset(Dataset):
def get_all_data(self):
"""
Return all data in the dataset
Return all data in the dataset.
:return: All data in the dataset
:return: A dictionary containing all the data in the dataset.
:rtype: dict
"""
index = list(range(len(self)))
return self.fetch_from_idx_list(index)
def fetch_from_idx_list(self, idx):
"""
Return data from the dataset given a list of indices
Return data from the dataset given a list of indices.
:param idx: List of indices
:param idx: List of indices.
:type idx: list
:return: Data from the dataset
:return: A dictionary containing the data at the given indices.
:rtype: dict
"""
to_return_dict = {}
for condition, data in self.conditions_dict.items():
# Get the indices for the current condition
@@ -190,30 +197,27 @@ class PinaDataset(Dataset):
@abstractmethod
def _retrive_data(self, data, idx_list):
"""
Retrieve data from the dataset given a list of indices
:param dict data: Dictionary containing the data
:param list idx_list: List of indices to retrieve
:return: Dictionary containing the data at the given indices
:rtype: dict
Abstract method to retrieve data from the dataset given a list of
indices.
"""
class PinaTensorDataset(PinaDataset):
"""
Class for the PINA dataset with torch.Tensor data
Dataset class for the PINA dataset with :class:`torch.Tensor` and
:class:`LabelTensor` data.
"""
# Override _retrive_data method for torch.Tensor data
def _retrive_data(self, data, idx_list):
"""
Retrieve data from the dataset given a list of indices
Retrieve data from the dataset given a list of indices.
:param data: Dictionary containing the data
(only torch.Tensor/LableTensor)
(only torch.Tensor/LabelTensor).
:type data: dict
:param list(int) idx_list: indices to retrieve
:return: Dictionary containing the data at the given indices
:param list(int) idx_list: Indices to retrieve.
:return: Dictionary containing the data at the given indices.
:rtype: dict
"""
@@ -222,9 +226,9 @@ class PinaTensorDataset(PinaDataset):
@property
def input(self):
"""
Method to return all input points from the dataset.
Return the input data for the dataset.
:return: Dictionary containing the input points
:return: Dictionary containing the input points.
:rtype: dict
"""
return {k: v["input"] for k, v in self.conditions_dict.items()}
@@ -232,15 +236,17 @@ class PinaTensorDataset(PinaDataset):
class PinaGraphDataset(PinaDataset):
"""
Class for the PINA dataset with torch_geometric.data.Data data
Dataset class for the PINA dataset with :class:`torch_geometric.data.Data`
and :class:`Graph` data.
"""
def _create_graph_batch(self, data):
"""
Create a LabelBatch object from a list of Data objects.
Create a LabelBatch object from a list of
:class:`torch_geometric.data.Data` objects.
:param data: List of Data or Graph objects
:type data: list(Data) | list(Graph)
:param data: List of items to collate in a single batch.
:type data: list(torch_geometric.data.Data) | list(Graph)
:return: LabelBatch object with all the graphs collated in a single batch
of disconnected graphs.
:rtype: LabelBatch
@@ -255,7 +261,7 @@ class PinaGraphDataset(PinaDataset):
:param data: torch.Tensor object of shape (N, ...) where N is the
number of data points.
:type data: torch.Tensor | LabelTensor
:return: reshaped torch.Tensor or LabelTensor object
:return: reshaped torch.Tensor or LabelTensor object.
:rtype: torch.Tensor | LabelTensor
"""
out = data.reshape(-1, *data.shape[2:])
@@ -263,12 +269,13 @@ class PinaGraphDataset(PinaDataset):
def create_batch(self, data):
"""
Create a Batch object from a list of Data objects.
Create a Batch object from a list of :class:`torch_geometric.data.Data`
objects.
:param data: List of Data objects
:param data: List of items to collate in a single batch.
:type data: list
:return: Batch object
:rtype: Batch or PinaBatch
:return: Batch object.
:rtype: Batch | PinaBatch
"""
if isinstance(data[0], Data):
@@ -278,13 +285,14 @@ class PinaGraphDataset(PinaDataset):
# Override _retrive_data method for graph handling
def _retrive_data(self, data, idx_list):
"""
Retrieve data from the dataset given a list of indices
Retrieve data from the dataset given a list of indices.
:param dict data: dictionary containing the data
:param list idx_list: list of indices to retrieve
:return: dictionary containing the data at the given indices
:param dict data: Dictionary containing the data.
:param list idx_list: List of indices to retrieve.
:return: Dictionary containing the data at the given indices.
:rtype: dict
"""
# Return the data from the current condition
# If the data is a list of Data objects, create a Batch object
# If the data is a list of torch.Tensor objects, create a torch.Tensor