Documentation and docstrings for graph and data

This commit is contained in:
FilippoOlivo
2025-03-10 15:57:15 +01:00
committed by Nicola Demo
parent 6ce0bafc2b
commit 635e3b3a75
3 changed files with 342 additions and 83 deletions

View File

@@ -23,16 +23,24 @@ class DummyDataloader:
    def __init__(self, dataset):
        """
        Prepare a dataloader object which returns the entire dataset in a
        single batch. Depending on the number of GPUs, there are two cases:

        - **Distributed Environment** (multiple GPUs):
            - Divides the dataset across processes using the rank and world
              size.
            - Fetches only the portion of data corresponding to the current
              process.
        - **Non-Distributed Environment** (single GPU):
            - Fetches the entire dataset.

        :param dataset: The dataset object to be processed.
        :type dataset: PinaDataset

        .. note:: This data loader is used when the batch size is None.
        """
        if (
            torch.distributed.is_available()
            and torch.distributed.is_initialized()
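As a rough illustration of the single-batch behaviour described in the docstring above, here is a standalone sketch in plain PyTorch. The SingleBatchLoader class and the toy dataset are made up for illustration and are not part of PINA; the sketch assumes a map-style dataset of equally shaped tensors.

import torch

class SingleBatchLoader:
    """Yield the whole (or per-rank slice of the) dataset as a single batch."""

    def __init__(self, dataset):
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            # Distributed case: keep only the indices belonging to this process.
            rank = torch.distributed.get_rank()
            world_size = torch.distributed.get_world_size()
            idx = list(range(rank, len(dataset), world_size))
        else:
            # Single-GPU / CPU case: keep everything.
            idx = list(range(len(dataset)))
        self.batch = torch.stack([dataset[i] for i in idx])

    def __iter__(self):
        # A single iteration returns the entire (local) dataset.
        yield self.batch

    def __len__(self):
        return 1

loader = SingleBatchLoader(list(torch.randn(8, 3)))
print(next(iter(loader)).shape)  # torch.Size([8, 3])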
@@ -67,23 +75,50 @@ class Collator:
    Class used to collate the batch.
    """

    def __init__(
        self, max_conditions_lengths, automatic_batching, dataset=None
    ):
        """
        Initialize the object, setting the collate function based on whether
        automatic batching is enabled or not.

        :param dict max_conditions_lengths: dict containing the maximum number
            of data points to consider in a single batch for each condition.
        :param bool automatic_batching: Whether automatic PyTorch batching is
            enabled or not.
        :param PinaDataset dataset: The dataset where the data is stored.
        """
        self.max_conditions_lengths = max_conditions_lengths
        # Set the collate function based on the batching strategy:
        # _collate_pina_dataloader is used when automatic batching is disabled,
        # _collate_torch_dataloader is used when automatic batching is enabled.
        self.callable_function = (
            self._collate_torch_dataloader
            if automatic_batching
            else (self._collate_pina_dataloader)
        )
        self.dataset = dataset

        # Set the function which performs the actual collation
        if isinstance(self.dataset, PinaTensorDataset):
            # If the dataset is a PinaTensorDataset, use this collate function
            self._collate = self._collate_tensor_dataset
        else:
            # If the dataset is a PinaDataset, use this collate function
            self._collate = self._collate_graph_dataset

    def _collate_pina_dataloader(self, batch):
        """
        Function used to create a batch when automatic batching is disabled.

        :param list(int) batch: List of integers representing the indices of
            the data points to be fetched.
        :return: Dictionary containing the data points fetched from the
            dataset.
        :rtype: dict
        """
        # Call the fetch_from_idx_list method of the dataset
        return self.dataset.fetch_from_idx_list(batch)

    def _collate_torch_dataloader(self, batch):
        """
        Function used to collate the batch when automatic batching is enabled.
        """
@@ -112,6 +147,19 @@ class Collator:
    @staticmethod
    def _collate_tensor_dataset(data_list):
        """
        Function used to collate the data when the dataset is a
        `PinaTensorDataset`.

        :param data_list: List of `torch.Tensor` or `LabelTensor` to be
            collated.
        :type data_list: list(torch.Tensor) | list(LabelTensor)
        :raises RuntimeError: If the data is not a `torch.Tensor` or a
            `LabelTensor`.
        :return: Batch of data.
        :rtype: dict
        """
        if isinstance(data_list[0], LabelTensor):
            return LabelTensor.stack(data_list)
        if isinstance(data_list[0], torch.Tensor):
@@ -119,15 +167,36 @@ class Collator:
raise RuntimeError("Data must be Tensors or LabelTensor ") raise RuntimeError("Data must be Tensors or LabelTensor ")
def _collate_graph_dataset(self, data_list): def _collate_graph_dataset(self, data_list):
"""
Function used to collate the data when the dataset is a
`PinaGraphDataset`.
:param data_list: List of `Data` or `Graph` to be collated.
:type data_list: list(Data) | list(Graph)
:raises RuntimeError: If the data is not a `Data` or a `Graph`.
:return: Batch of data
:rtype: dict
"""
if isinstance(data_list[0], LabelTensor): if isinstance(data_list[0], LabelTensor):
return LabelTensor.cat(data_list) return LabelTensor.cat(data_list)
if isinstance(data_list[0], torch.Tensor): if isinstance(data_list[0], torch.Tensor):
return torch.cat(data_list) return torch.cat(data_list)
if isinstance(data_list[0], Data): if isinstance(data_list[0], Data):
return self.dataset.create_graph_batch(data_list) return self.dataset.create_batch(data_list)
raise RuntimeError("Data must be Tensors or LabelTensor or pyG Data") raise RuntimeError("Data must be Tensors or LabelTensor or pyG Data")
def __call__(self, batch): def __call__(self, batch):
"""
Call the function to collate the batch, defined in __init__.
:param batch: list of indices or list of retrieved data
:type batch: list(int) | list(dict)
:return: Dictionary containing the data points fetched from the dataset,
collated.
:rtype: dict
"""
return self.callable_function(batch) return self.callable_function(batch)
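The dispatch above (pick the collate function once in __init__, then let __call__ forward to it) can be mimicked with a small standalone collator. ToyCollator below is illustrative only and uses plain tensors instead of PINA conditions.

import torch
from torch.utils.data import DataLoader

class ToyCollator:
    def __init__(self, automatic_batching, dataset=None):
        self.dataset = dataset
        # Pick the collate function once, based on the batching strategy.
        self.callable_function = (
            self._collate_items if automatic_batching else self._collate_indices
        )

    def _collate_indices(self, batch):
        # Automatic batching disabled: batch is a list of indices.
        return torch.stack([self.dataset[i] for i in batch])

    def _collate_items(self, batch):
        # Automatic batching enabled: batch is a list of already-fetched items.
        return torch.stack(batch)

    def __call__(self, batch):
        return self.callable_function(batch)

dataset = [torch.randn(3) for _ in range(10)]
loader = DataLoader(dataset, batch_size=5, collate_fn=ToyCollator(True, dataset))
print(next(iter(loader)).shape)  # torch.Size([5, 3])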
@@ -137,6 +206,16 @@ class PinaSampler:
""" """
def __new__(cls, dataset, shuffle): def __new__(cls, dataset, shuffle):
"""
Create the sampler instance, according to shuffle and whether the
environment is distributed or not.
:param PinaDataset dataset: The dataset to be sampled.
:param bool shuffle: whether to shuffle the dataset or not before
sampling.
:return: The sampler instance.
:rtype: torch.utils.data.Sampler
"""
if ( if (
torch.distributed.is_available() torch.distributed.is_available()
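The sampler choice sketched below follows the pattern suggested by the check above: shard the indices with a DistributedSampler when a process group is initialized, otherwise fall back to a random or sequential sampler. The make_sampler helper is hypothetical, not the actual PinaSampler implementation.

import torch
from torch.utils.data import DistributedSampler, RandomSampler, SequentialSampler

def make_sampler(dataset, shuffle):
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        # Each process sees a disjoint shard of the indices.
        return DistributedSampler(dataset, shuffle=shuffle)
    # Non-distributed case: shuffle or iterate in order.
    return RandomSampler(dataset) if shuffle else SequentialSampler(dataset)

sampler = make_sampler(list(range(100)), shuffle=True)
print(len(list(sampler)))  # 100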
@@ -173,29 +252,24 @@ class PinaDataModule(LightningDataModule):
""" """
Initialize the object, creating datasets based on the input problem. Initialize the object, creating datasets based on the input problem.
:param problem: The problem defining the dataset. :param AbstractProblem problem: The problem containing the data on which
:type problem: AbstractProblem to train/test the model.
:param train_size: Fraction or number of elements in the training split. :param float train_size: Fraction or number of elements in the training
:type train_size: float split.
:param test_size: Fraction or number of elements in the test split. :param float test_size: Fraction or number of elements in the test
:type test_size: float split.
:param val_size: Fraction or number of elements in the validation split. :param float val_size: Fraction or number of elements in the validation
:type val_size: float split.
:param batch_size: Batch size used for training. If None, the entire :param batch_size: The batch size used for training. If `None`, the
dataset is used per batch. entire dataset is used per batch.
:type batch_size: int or None :type batch_size: int | None
:param shuffle: Whether to shuffle the dataset before splitting. :param bool shuffle: Whether to shuffle the dataset before splitting.
:type shuffle: bool :param bool repeat: Whether to repeat the dataset indefinitely.
:param repeat: Whether to repeat the dataset indefinitely.
:type repeat: bool
:param automatic_batching: Whether to enable automatic batching. :param automatic_batching: Whether to enable automatic batching.
:type automatic_batching: bool :param int num_workers: Number of worker threads for data loading.
:param num_workers: Number of worker threads for data loading.
Default 0 (serial loading) Default 0 (serial loading)
:type num_workers: int :param bool pin_memory: Whether to use pinned memory for faster data
:param pin_memory: Whether to use pinned memory for faster data
transfer to GPU. (Default False) transfer to GPU. (Default False)
:type pin_memory: bool
""" """
super().__init__() super().__init__()
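For the split parameters above, a fraction such as train_size=0.7 ultimately has to become an element count. A minimal sketch of that conversion follows; the helper name and the rounding policy are assumptions, not PINA's exact logic.

import torch
from torch.utils.data import random_split

def split_lengths(n, train_size, test_size, val_size):
    train, test, val = int(n * train_size), int(n * test_size), int(n * val_size)
    # Give any rounding leftover to the training split.
    train += n - (train + test + val)
    return train, test, val

dataset = list(torch.arange(100))
train, test, val = random_split(dataset, split_lengths(len(dataset), 0.7, 0.2, 0.1))
print(len(train), len(test), len(val))  # 70 20 10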
@@ -365,10 +439,14 @@ class PinaDataModule(LightningDataModule):
        sampler = PinaSampler(dataset, shuffle)
        if self.automatic_batching:
            collate = Collator(
                self.find_max_conditions_lengths(split),
                self.automatic_batching,
                dataset=dataset,
            )
        else:
            collate = Collator(
                None, self.automatic_batching, dataset=dataset
            )
        return DataLoader(
            dataset,
            self.batch_size,
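The call above wires a sampler and a collator into a standard torch.utils.data.DataLoader. The toy version below shows the same wiring with plain objects; the values of batch_size, num_workers and pin_memory are arbitrary.

import torch
from torch.utils.data import DataLoader, SequentialSampler

dataset = [torch.randn(3) for _ in range(12)]
loader = DataLoader(
    dataset,
    batch_size=4,
    sampler=SequentialSampler(dataset),
    collate_fn=torch.stack,   # stand-in for the Collator instance
    num_workers=0,
    pin_memory=False,
)
for batch in loader:
    print(batch.shape)  # torch.Size([4, 3]), three times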
@@ -413,23 +491,51 @@ class PinaDataModule(LightningDataModule):
    def train_dataloader(self):
        """
        Create the training dataloader.

        :return: The training dataloader.
        :rtype: DataLoader
        """
        return self._create_dataloader("train", self.train_dataset)

    def test_dataloader(self):
        """
        Create the testing dataloader.

        :return: The testing dataloader.
        :rtype: DataLoader
        """
        return self._create_dataloader("test", self.test_dataset)

    @staticmethod
    def _transfer_batch_to_device_dummy(batch, device, dataloader_idx):
        """
        Transfer the batch to the device. This method is called in the
        training loop and is used to transfer the batch to the device.
        It is used when the batch size is None: the batch has already been
        transferred to the device.

        :param list(tuple) batch: List of tuples where the first element of
            each tuple is the condition name and the second element is the
            data.
        :param device: The device to which the batch is transferred.
        :type device: torch.device
        :param int dataloader_idx: The index of the dataloader.
        :return: The batch transferred to the device.
        :rtype: list(tuple)
        """
        return batch

    def _transfer_batch_to_device(self, batch, device, dataloader_idx):
        """
        Transfer the batch to the device. This method is called in the
        training loop and is used to transfer the batch to the device.

        :param dict batch: The batch to be transferred to the device.
        :param device: The device to which the batch is transferred.
        :type device: torch.device
        :param int dataloader_idx: The index of the dataloader.
        :return: The batch transferred to the device.
        :rtype: list(tuple)
        """
        batch = [
            (
@@ -456,7 +562,10 @@ class PinaDataModule(LightningDataModule):
    @property
    def input(self):
        """
        Return all the input points coming from all the datasets.

        :return: The input points for training.
        :rtype: dict
        """
        to_return = {}
        if hasattr(self, "train_dataset") and self.train_dataset is not None:
@@ -464,5 +573,5 @@ class PinaDataModule(LightningDataModule):
if hasattr(self, "val_dataset") and self.val_dataset is not None: if hasattr(self, "val_dataset") and self.val_dataset is not None:
to_return["val"] = self.val_dataset.input to_return["val"] = self.val_dataset.input
if hasattr(self, "test_dataset") and self.test_dataset is not None: if hasattr(self, "test_dataset") and self.test_dataset is not None:
to_return = self.test_dataset.input to_return["test"] = self.test_dataset.input
return to_return return to_return
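Returning to _transfer_batch_to_device above: the batch is a list of (condition_name, data_dict) pairs, and moving it means moving every tensor in every dict. Below is a simplified stand-in, not the PINA method, and it ignores LabelTensor specifics.

import torch

def transfer_batch_to_device(batch, device):
    return [
        (name, {key: value.to(device) for key, value in data.items()})
        for name, data in batch
    ]

batch = [("cond", {"input": torch.randn(4, 2), "target": torch.randn(4, 1)})]
moved = transfer_batch_to_device(batch, torch.device("cpu"))
print(moved[0][1]["input"].device)  # cpu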

View File

@@ -10,13 +10,31 @@ from ..graph import Graph, LabelBatch
class PinaDatasetFactory:
    """
    Factory class for the PINA dataset.

    Depending on the type inside the conditions, it creates a different
    dataset object:

    - :class:`PinaTensorDataset` for `torch.Tensor`
    - :class:`PinaGraphDataset` for `list` of `torch_geometric.data.Data`
      objects
    """

    def __new__(cls, conditions_dict, **kwargs):
        """
        Instantiate the appropriate subclass of :class:`PinaDataset`.
        If a graph is present in the conditions, returns a
        :class:`PinaGraphDataset`, otherwise returns a
        :class:`PinaTensorDataset`.

        :param dict conditions_dict: Dictionary containing the conditions.
        :return: A subclass of :class:`PinaDataset`.
        :rtype: :class:`PinaTensorDataset` | :class:`PinaGraphDataset`
        :raises ValueError: If an empty dictionary is provided.
        """
        # Check if conditions_dict is empty
        if len(conditions_dict) == 0:
            raise ValueError("No conditions provided")
@@ -31,9 +49,21 @@ class PinaDatasetFactory:
    @staticmethod
    def _is_graph_dataset(conditions_dict):
        """
        Check if a graph is present in the conditions.

        :param conditions_dict: Dictionary containing the conditions.
        :type conditions_dict: dict
        :return: True if a graph is present in the conditions, False
            otherwise.
        :rtype: bool
        """
        # Iterate over the conditions dictionary
        for v in conditions_dict.values():
            # Iterate over the values of the current condition
            for cond in v.values():
                # Check if the current value is a list of Data objects
                if isinstance(cond, (Data, Graph, list, tuple)):
                    return True
        return False
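A toy run of the check above, using only torch_geometric.data.Data; the condition names and contents are made up.

import torch
from torch_geometric.data import Data

def is_graph_dataset(conditions_dict):
    for condition in conditions_dict.values():
        for value in condition.values():
            if isinstance(value, (Data, list, tuple)):
                return True
    return False

conditions = {
    "tensor_cond": {"input": torch.randn(10, 2), "target": torch.randn(10, 1)},
    "graph_cond": {"input": [Data(x=torch.randn(4, 2))], "target": torch.randn(1, 1)},
}
print(is_graph_dataset(conditions))  # True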
@@ -46,6 +76,19 @@ class PinaDataset(Dataset):
    def __init__(
        self, conditions_dict, max_conditions_lengths, automatic_batching
    ):
        """
        Initialize the :class:`PinaDataset`.
        Stores the conditions dictionary, the maximum number of conditions to
        consider, and the automatic batching flag.

        :param dict conditions_dict: Dictionary containing the conditions.
        :param dict max_conditions_lengths: Maximum number of data points to
            consider in a single batch for each condition.
        :param bool automatic_batching: Whether PyTorch automatic batching is
            enabled in :class:`PinaDataModule`.
        """
        # Store the conditions dictionary
        self.conditions_dict = conditions_dict
        # Store the maximum number of conditions to consider
@@ -63,7 +106,13 @@ class PinaDataset(Dataset):
            self._getitem_func = self._getitem_dummy

    def _get_max_len(self):
        """
        Returns the length of the longest condition in the dataset.

        :return: Length of the longest condition in the dataset.
        :rtype: int
        """
        max_len = 0
        for condition in self.conditions_dict.values():
            max_len = max(max_len, len(condition["input"]))
@@ -76,10 +125,29 @@ class PinaDataset(Dataset):
        return self._getitem_func(idx)

    def _getitem_dummy(self, idx):
        """
        Return the index itself. This is used when automatic batching is
        disabled, to postpone the data retrieval to the dataloader.

        :param idx: Index.
        :type idx: int
        :return: Index.
        :rtype: int
        """
        # If automatic batching is disabled, return the index itself
        return idx

    def _getitem_int(self, idx):
        """
        Return the data at the given index in the dataset. This is used when
        automatic batching is enabled.

        :param int idx: Index.
        :return: A dictionary containing the data at the given index.
        :rtype: dict
        """
        # If automatic batching is enabled, return the data at the given index
        return {
            k: {k_data: v[k_data][idx % len(v["input"])] for k_data in v.keys()}
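The modulo in the dictionary comprehension above lets conditions of different lengths share the same index range: shorter conditions simply wrap around. A toy version with made-up condition names:

import torch

conditions = {
    "short": {"input": torch.arange(3), "target": torch.arange(3) * 10},
    "long": {"input": torch.arange(5), "target": torch.arange(5) * 10},
}

def getitem(idx):
    return {
        name: {key: val[idx % len(cond["input"])] for key, val in cond.items()}
        for name, cond in conditions.items()
    }

print(getitem(4))  # "short" wraps around to index 1, "long" uses index 4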
@@ -121,7 +189,14 @@ class PinaDataset(Dataset):
    @abstractmethod
    def _retrive_data(self, data, idx_list):
        """
        Retrieve data from the dataset given a list of indices.

        :param dict data: Dictionary containing the data.
        :param list idx_list: List of indices to retrieve.
        :return: Dictionary containing the data at the given indices.
        :rtype: dict
        """


class PinaTensorDataset(PinaDataset):
@@ -131,12 +206,26 @@ class PinaTensorDataset(PinaDataset):
    # Override _retrive_data method for torch.Tensor data
    def _retrive_data(self, data, idx_list):
        """
        Retrieve data from the dataset given a list of indices.

        :param data: Dictionary containing the data (only
            torch.Tensor/LabelTensor).
        :type data: dict
        :param list(int) idx_list: Indices to retrieve.
        :return: Dictionary containing the data at the given indices.
        :rtype: dict
        """
        return {k: v[idx_list] for k, v in data.items()}

    @property
    def input(self):
        """
        Method to return all input points from the dataset.

        :return: Dictionary containing the input points.
        :rtype: dict
        """
        return {k: v["input"] for k, v in self.conditions_dict.items()}
@@ -146,15 +235,33 @@ class PinaGraphDataset(PinaDataset):
    Class for the PINA dataset with torch_geometric.data.Data data.
    """

    def _create_graph_batch(self, data):
        """
        Create a LabelBatch object from a list of Data objects.

        :param data: List of Data or Graph objects.
        :type data: list(Data) | list(Graph)
        :return: LabelBatch object with all the graphs collated in a single
            batch of disconnected graphs.
        :rtype: LabelBatch
        """
        batch = LabelBatch.from_data_list(data)
        return batch

    def _create_tensor_batch(self, data):
        """
        Create a torch.Tensor object from a list of torch.Tensor objects.

        :param data: torch.Tensor object of shape (N, ...), where N is the
            number of data points.
        :type data: torch.Tensor | LabelTensor
        :return: Reshaped torch.Tensor or LabelTensor object.
        :rtype: torch.Tensor | LabelTensor
        """
        out = data.reshape(-1, *data.shape[2:])
        return out

    def create_batch(self, data):
        """
        Create a Batch object from a list of Data objects.
@@ -163,20 +270,29 @@ class PinaGraphDataset(PinaDataset):
        :return: Batch object.
        :rtype: Batch or PinaBatch
        """
        if isinstance(data[0], Data):
            return self._create_graph_batch(data)
        return self._create_tensor_batch(data)

    # Override _retrive_data method for graph handling
    def _retrive_data(self, data, idx_list):
        """
        Retrieve data from the dataset given a list of indices.

        :param dict data: Dictionary containing the data.
        :param list idx_list: List of indices to retrieve.
        :return: Dictionary containing the data at the given indices.
        :rtype: dict
        """
        # Return the data from the current condition.
        # If the data is a list of Data objects, create a Batch object.
        # If the data is a list of torch.Tensor objects, create a torch.Tensor.
        return {
            k: (
                self._create_graph_batch([v[i] for i in idx_list])
                if isinstance(v, list)
                else self._create_tensor_batch(v[idx_list])
            )
            for k, v in data.items()
        }
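The graph branch above ultimately relies on torch_geometric's batching, which merges a list of Data objects into one large disconnected graph. A minimal example with plain Batch.from_data_list (PINA's LabelBatch additionally carries labels):

import torch
from torch_geometric.data import Batch, Data

graphs = [
    Data(x=torch.randn(3, 2), edge_index=torch.tensor([[0, 1], [1, 2]]))
    for _ in range(4)
]
batch = Batch.from_data_list(graphs)
print(batch.num_graphs, batch.x.shape)  # 4 torch.Size([12, 2])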

View File

@@ -19,6 +19,9 @@ class Graph(Data):
        **kwargs,
    ):
        """
        Instantiates a new instance of the Graph class, performing type
        consistency checks.

        :param kwargs: Parameters to construct the Graph object.
        :return: A new instance of the Graph class.
        :rtype: Graph
@@ -42,7 +45,10 @@ class Graph(Data):
        **kwargs,
    ):
        """
        Initialize the Graph object by setting the node features, edge index,
        edge attributes, and positions. The edge index is preprocessed to make
        the graph undirected if required. For more details, see
        :class:`torch_geometric.data.Data`.

        :param x: Optional tensor of node features (N, F) where F is the
            number of features per node.
@@ -69,6 +75,13 @@ class Graph(Data):
        )

    def _check_type_consistency(self, **kwargs):
        """
        Check the consistency of the types of the input data.

        :param kwargs: Attributes to be checked for consistency.
        :type kwargs: dict
        """
        # Default types, specified in cls.__new__; by default they are None.
        # If specified in **kwargs, they get overridden.
        x, pos, edge_index, edge_attr = None, None, None, None
@@ -92,8 +105,10 @@ class Graph(Data):
    def _check_pos_consistency(pos):
        """
        Check if the position tensor is consistent.

        :param torch.Tensor pos: The position tensor.
        """
        if pos is not None:
            check_consistency(pos, (torch.Tensor, LabelTensor))
            if pos.ndim != 2:
@@ -103,8 +118,10 @@ class Graph(Data):
    def _check_edge_index_consistency(edge_index):
        """
        Check if the edge index is consistent.

        :param torch.Tensor edge_index: The edge index tensor.
        """
        check_consistency(edge_index, (torch.Tensor, LabelTensor))
        if edge_index.ndim != 2:
            raise ValueError("edge_index must be a 2D tensor.")
@@ -114,11 +131,13 @@ class Graph(Data):
    @staticmethod
    def _check_edge_attr_consistency(edge_attr, edge_index):
        """
        Check if the edge attribute tensor is consistent in type and shape
        with the edge index.

        :param torch.Tensor edge_attr: The edge attribute tensor.
        :param torch.Tensor edge_index: The edge index tensor.
        """
        if edge_attr is not None:
            check_consistency(edge_attr, (torch.Tensor, LabelTensor))
            if edge_attr.ndim != 2:
@@ -134,10 +153,13 @@ class Graph(Data):
    @staticmethod
    def _check_x_consistency(x, pos=None):
        """
        Check if the input tensor x is consistent with the position tensor
        `pos`.

        :param torch.Tensor x: The input tensor.
        :param torch.Tensor pos: The position tensor.
        """
        if x is not None:
            check_consistency(x, (torch.Tensor, LabelTensor))
            if x.ndim != 2:
@@ -152,22 +174,24 @@ class Graph(Data):
    @staticmethod
    def _preprocess_edge_index(edge_index, undirected):
        """
        Preprocess the edge index to make the graph undirected (if required).

        :param torch.Tensor edge_index: The edge index.
        :param bool undirected: Whether the graph is undirected.
        :return: The preprocessed edge index.
        :rtype: torch.Tensor
        """
        if undirected:
            edge_index = to_undirected(edge_index)
        return edge_index
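For reference, a quick illustration of to_undirected (assumed here to be torch_geometric.utils.to_undirected): it adds the reverse of every edge.

import torch
from torch_geometric.utils import to_undirected

edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
print(to_undirected(edge_index))
# tensor([[0, 0, 1, 1, 2, 2],
#         [1, 2, 0, 2, 0, 1]])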
    def extract(self, labels, attr="x"):
        """
        Perform extraction of labels from the attribute specified by `attr`.

        :param labels: Labels to extract.
        :type labels: list[str] | tuple[str] | str | dict
        :return: Batch object with extraction performed on `attr`.
        :rtype: PinaBatch
        """
@@ -193,21 +217,24 @@ class GraphBuilder:
        **kwargs,
    ):
        """
        Compute the edge attributes and create a new instance of the Graph
        class.

        :param pos: A tensor of shape (N, D) representing the positions of N
            points in D-dimensional space.
        :type pos: torch.Tensor | LabelTensor
        :param edge_index: A tensor of shape (2, E) representing the indices
            of the graph's edges.
        :type edge_index: torch.Tensor
        :param x: Optional tensor of node features of shape (N, F), where F is
            the number of features per node.
        :type x: torch.Tensor | LabelTensor, optional
        :param edge_attr: Optional tensor of edge attributes of shape (E, F),
            where F is the number of features per edge.
        :type edge_attr: torch.Tensor, optional
        :param custom_edge_func: A custom function to compute edge attributes.
            If provided, overrides `edge_attr`.
        :type custom_edge_func: callable, optional
        :param kwargs: Additional keyword arguments passed to the Graph class
            constructor.
        :return: A Graph instance constructed using the provided information.
@@ -249,18 +276,18 @@ class RadiusGraph(GraphBuilder):
    def __new__(cls, pos, radius, **kwargs):
        """
        Extends the `GraphBuilder` class to compute edge_index based on a
        radius. Each point is connected to all the points within the radius.

        :param pos: A tensor of shape (N, D) representing the positions of N
            points in D-dimensional space.
        :type pos: torch.Tensor | LabelTensor
        :param radius: The radius within which points are connected.
        :type radius: float
        :param kwargs: Additional keyword arguments to be passed to the
            `GraphBuilder` and `Graph` constructors.
        :return: A `Graph` instance containing the input information and the
            computed edge_index.
        :rtype: Graph
        """
        edge_index = cls.compute_radius_graph(pos, radius)
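A possible way to build such a radius-based edge_index with plain torch is sketched below; this is illustrative only and the actual compute_radius_graph implementation may differ.

import torch

def radius_edges(points, radius):
    dist = torch.cdist(points, points, p=2)
    # Connect distinct pairs of points closer than the radius.
    mask = (dist <= radius) & ~torch.eye(len(points), dtype=torch.bool)
    return mask.nonzero().t()  # shape (2, E)

pos = torch.rand(5, 2)
print(radius_edges(pos, radius=0.5).shape)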
@@ -269,7 +296,8 @@ class RadiusGraph(GraphBuilder):
    @staticmethod
    def compute_radius_graph(points, radius):
        """
        Computes edge_index for a given set of points based on the radius.
        Each point is connected to all the points within the radius.

        :param points: A tensor of shape (N, D) representing the positions of
            N points in D-dimensional space.
@@ -295,7 +323,7 @@ class KNNGraph(GraphBuilder):
    def __new__(cls, pos, neighbours, **kwargs):
        """
        Creates a new instance of the Graph class using the k-nearest
        neighbors algorithm to define the edges.

        :param pos: A tensor of shape (N, D) representing the positions of N
            points in D-dimensional space.
@@ -323,8 +351,9 @@ class KNNGraph(GraphBuilder):
            N points in D-dimensional space.
        :type points: torch.Tensor | LabelTensor
        :param int k: The number of nearest neighbors to find for each point.
        :return: A tensor of shape (2, E), where E is the number of edges,
            representing the edge indices of the KNN graph.
        :rtype: torch.Tensor
        """
        dist = torch.cdist(points, points, p=2)
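Starting from the pairwise distances computed above, a k-nearest-neighbour edge_index can be obtained with topk; the continuation below is a sketch, not necessarily the exact code that follows in the file.

import torch

def knn_edges(points, k):
    dist = torch.cdist(points, points, p=2)
    dist.fill_diagonal_(float("inf"))           # exclude self-loops
    knn = dist.topk(k, largest=False).indices   # (N, k) neighbour indices
    src = torch.arange(len(points)).repeat_interleave(k)
    return torch.stack([src, knn.reshape(-1)])  # shape (2, N * k)

pos = torch.rand(6, 2)
print(knn_edges(pos, k=2).shape)  # torch.Size([2, 12])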
@@ -343,6 +372,11 @@ class LabelBatch(Batch):
    def from_data_list(cls, data_list):
        """
        Create a Batch object from a list of Data objects.

        :param data_list: List of Data/Graph objects.
        :type data_list: list[Data] | list[Graph]
        :return: A Batch object containing the data in the list.
        :rtype: Batch
        """
        # Store the labels of Data/Graph objects (all data have the same labels)
        # If the data do not contain labels, labels is an empty dictionary,