Documentation and docstring graph and data

This commit is contained in:
FilippoOlivo
2025-03-10 15:57:15 +01:00
committed by Nicola Demo
parent 6ce0bafc2b
commit 635e3b3a75
3 changed files with 342 additions and 83 deletions

View File

@@ -10,13 +10,31 @@ from ..graph import Graph, LabelBatch
class PinaDatasetFactory:
"""
Factory class for the PINA dataset. Depending on the type inside the
conditions it creates a different dataset object:
- PinaTensorDataset for torch.Tensor
- PinaGraphDataset for list of torch_geometric.data.Data objects
Factory class for the PINA dataset.
Depending on the type inside the conditions, it creates a different dataset
object:
- :class:`PinaTensorDataset` for `torch.Tensor`
- :class:`PinaGraphDataset` for `list` of `torch_geometric.data.Data`
objects
"""
def __new__(cls, conditions_dict, **kwargs):
"""
Instantiate the appropriate subclass of :class:`PinaDataset`.
If a graph is present in the conditions, returns a
:class:`PinaGraphDataset`, otherwise returns a
:class:`PinaTensorDataset`.
:param dict conditions_dict: Dictionary containing the conditions.
:return: A subclass of :class:`PinaDataset`.
:rtype: :class:`PinaTensorDataset` | :class:`PinaGraphDataset`
:raises ValueError: If an empty dictionary is provided.
"""
# Check if conditions_dict is empty
if len(conditions_dict) == 0:
raise ValueError("No conditions provided")
@@ -31,9 +49,21 @@ class PinaDatasetFactory:
@staticmethod
def _is_graph_dataset(conditions_dict):
"""
Check if a graph is present in the conditions.

Every value of every condition is inspected, and the search stops at the
first value that is a ``Data``, ``Graph``, ``list`` or ``tuple`` instance.
Note that any ``list`` or ``tuple`` matches, whatever its elements are.

:param conditions_dict: Dictionary containing the conditions. Each of its
values must itself provide ``values()``.
:type conditions_dict: dict
:return: True if a graph is present in the conditions, False otherwise.
:rtype: bool
"""
# Iterate over the conditions dictionary
for v in conditions_dict.values():
# Iterate over the values of the current condition
for cond in v.values():
if isinstance(cond, (Data, Graph, list)):
# A single graph object, or any list/tuple, marks the dataset as graph-based
if isinstance(cond, (Data, Graph, list, tuple)):
return True
return False
@@ -46,6 +76,19 @@ class PinaDataset(Dataset):
def __init__(
self, conditions_dict, max_conditions_lengths, automatic_batching
):
"""
Initialize the :class:`PinaDataset`.
Stores the conditions dictionary, the maximum number of conditions to
consider, and the automatic batching flag.
:param dict conditions_dict: Dictionary containing the conditions.
:param dict max_conditions_lengths: Maximum number of data points to
consider in a single batch for each condition.
:param bool automatic_batching: Whether PyTorch automatic batching is
enabled in :class:`PinaDataModule`.
"""
# Store the conditions dictionary
self.conditions_dict = conditions_dict
# Store the maximum number of conditions to consider
@@ -63,7 +106,13 @@ class PinaDataset(Dataset):
self._getitem_func = self._getitem_dummy
def _get_max_len(self):
""""""
"""
Returns the length of the longest condition in the dataset
:return: Length of the longest condition in the dataset
:rtype: int
"""
max_len = 0
for condition in self.conditions_dict.values():
max_len = max(max_len, len(condition["input"]))
@@ -76,10 +125,29 @@ class PinaDataset(Dataset):
return self._getitem_func(idx)
def _getitem_dummy(self, idx):
"""
Return the index itself. This is used when automatic batching is
disabled to postpone the data retrieval to the dataloader.

:param idx: Index
:type idx: int
:return: The same index, unchanged
:rtype: int
"""
# Automatic batching is disabled: hand the index back untouched, no data is read here
return idx
def _getitem_int(self, idx):
"""
Return the data at the given index in the dataset. This is used when
automatic batching is enabled.
:param int idx: Index
:return: A dictionary containing the data at the given index
:rtype: dict
"""
# If automatic batching is enabled, return the data at the given index
return {
k: {k_data: v[k_data][idx % len(v["input"])] for k_data in v.keys()}
@@ -121,7 +189,14 @@ class PinaDataset(Dataset):
@abstractmethod
def _retrive_data(self, data, idx_list):
pass
"""
Retrieve data from the dataset given a list of indices.

Abstract hook: this declaration carries no behaviour of its own, the
concrete dataset classes override it.

:param dict data: Dictionary containing the data.
:param list idx_list: List of indices to retrieve.
:return: Dictionary containing the data at the given indices.
:rtype: dict
"""
class PinaTensorDataset(PinaDataset):
@@ -131,12 +206,26 @@ class PinaTensorDataset(PinaDataset):
# Override _retrive_data method for torch.Tensor data
def _retrive_data(self, data, idx_list):
"""
Retrieve data from the dataset given a list of indices.

Each value of ``data`` is indexed with ``idx_list`` as a whole; the keys
are kept unchanged.

:param data: Dictionary containing the data
(only torch.Tensor/LabelTensor)
:type data: dict
:param list(int) idx_list: Indices to retrieve.
:return: Dictionary containing the data at the given indices.
:rtype: dict
"""
return {k: v[idx_list] for k, v in data.items()}
@property
def input(self):
"""
Method to return input points for training.
Method to return all input points from the dataset.

:return: Dictionary mapping each key of ``conditions_dict`` to the
``"input"`` entry of the corresponding condition.
:rtype: dict
"""
return {k: v["input"] for k, v in self.conditions_dict.items()}
@@ -146,15 +235,33 @@ class PinaGraphDataset(PinaDataset):
Class for the PINA dataset with torch_geometric.data.Data data
"""
def _create_graph_batch_from_list(self, data):
def _create_graph_batch(self, data):
"""
Create a LabelBatch object from a list of Data objects.

:param data: List of Data or Graph objects.
:type data: list(Data) | list(Graph)
:return: LabelBatch object with all the graphs collated into a single
batch of disconnected graphs.
:rtype: LabelBatch
"""
# Collation is delegated entirely to LabelBatch.from_data_list
batch = LabelBatch.from_data_list(data)
return batch
def _create_output_batch(self, data):
def _create_tensor_batch(self, data):
"""
Collapse the two leading dimensions of a tensor into a single one.

:param data: torch.Tensor object of shape (N, ...) where N is the
number of data points. Presumably at least two dimensions are
expected, since ``data.shape[2:]`` is kept as the trailing shape
(TODO confirm against the callers).
:type data: torch.Tensor | LabelTensor
:return: Reshaped torch.Tensor or LabelTensor object.
:rtype: torch.Tensor | LabelTensor
"""
# Merge dims 0 and 1; every dimension from index 2 onwards is preserved
out = data.reshape(-1, *data.shape[2:])
return out
def create_graph_batch(self, data):
def create_batch(self, data):
"""
Create a Batch object from a list of Data objects.
@@ -163,20 +270,29 @@ class PinaGraphDataset(PinaDataset):
:return: Batch object
:rtype: Batch or PinaBatch
"""
if isinstance(data[0], Data):
return self._create_graph_batch_from_list(data)
return self._create_output_batch(data)
return self._create_graph_batch(data)
return self._create_tensor_batch(data)
# Override _retrive_data method for graph handling
def _retrive_data(self, data, idx_list):
"""
Retrieve data from the dataset given a list of indices.

:param dict data: Dictionary containing the data.
:param list idx_list: List of indices to retrieve.
:return: Dictionary with the same keys as ``data``, each holding the
batch built from the entries at the given indices.
:rtype: dict
"""
# Return the data from the current condition
# If the value is a list, pick the requested items one by one and collate them into a batch
# Otherwise index the value with idx_list as a whole and reshape the result into a tensor batch
return {
k: (
self._create_graph_batch_from_list([v[i] for i in idx_list])
self._create_graph_batch([v[i] for i in idx_list])
if isinstance(v, list)
else self._create_output_batch(v[idx_list])
else self._create_tensor_batch(v[idx_list])
)
for k, v in data.items()
}