Documentation and docstring graph and data
This commit is contained in:
committed by
Nicola Demo
parent
6ce0bafc2b
commit
635e3b3a75
@@ -23,16 +23,24 @@ class DummyDataloader:
|
||||
|
||||
def __init__(self, dataset):
|
||||
"""
|
||||
param dataset: The dataset object to be processed.
|
||||
:notes:
|
||||
- **Distributed Environment**:
|
||||
- Divides the dataset across processes using the
|
||||
rank and world size.
|
||||
- Fetches only the portion of data corresponding to
|
||||
the current process.
|
||||
- **Non-Distributed Environment**:
|
||||
- Fetches the entire dataset.
|
||||
Prepare a dataloader object which will return the entire dataset
|
||||
in a single batch. Depending on the number of GPUs, we
|
||||
have the following cases:
|
||||
|
||||
- **Distributed Environment** (multiple GPUs):
|
||||
- Divides the dataset across processes using the rank and world
|
||||
size.
|
||||
- Fetches only the portion of data corresponding to the current
|
||||
process.
|
||||
- **Non-Distributed Environment** (single GPU):
|
||||
- Fetches the entire dataset.
|
||||
|
||||
:param dataset: The dataset object to be processed.
|
||||
:type dataset: PinaDataset
|
||||
|
||||
.. note:: This data loader is used when the batch size is None.
|
||||
"""
|
||||
|
||||
if (
|
||||
torch.distributed.is_available()
|
||||
and torch.distributed.is_initialized()
|
||||
@@ -67,23 +75,50 @@ class Collator:
|
||||
Class used to collate the batch
|
||||
"""
|
||||
|
||||
def __init__(self, max_conditions_lengths, dataset=None):
|
||||
def __init__(
|
||||
self, max_conditions_lengths, automatic_batching, dataset=None
|
||||
):
|
||||
"""
|
||||
Initialize the object, setting the collate function based on whether
|
||||
automatic batching is enabled or not.
|
||||
|
||||
:param dict max_conditions_lengths: dict containing the maximum number
|
||||
of data points to consider in a single batch for each condition.
|
||||
:param PinaDataset dataset: The dataset where the data is stored.
|
||||
"""
|
||||
|
||||
self.max_conditions_lengths = max_conditions_lengths
|
||||
# Set the collate function based on the batching strategy
|
||||
# collate_pina_dataloader is used when automatic batching is disabled
|
||||
# collate_torch_dataloader is used when automatic batching is enabled
|
||||
self.callable_function = (
|
||||
self._collate_custom_dataloader
|
||||
if max_conditions_lengths is None
|
||||
else (self._collate_standard_dataloader)
|
||||
self._collate_torch_dataloader
|
||||
if automatic_batching
|
||||
else (self._collate_pina_dataloader)
|
||||
)
|
||||
self.dataset = dataset
|
||||
|
||||
# Set the function which performs the actual collation
|
||||
if isinstance(self.dataset, PinaTensorDataset):
|
||||
# If the dataset is a PinaTensorDataset, use this collate function
|
||||
self._collate = self._collate_tensor_dataset
|
||||
else:
|
||||
# If the dataset is a PinaDataset, use this collate function
|
||||
self._collate = self._collate_graph_dataset
|
||||
|
||||
def _collate_custom_dataloader(self, batch):
|
||||
def _collate_pina_dataloader(self, batch):
|
||||
"""
|
||||
Function used to create a batch when automatic batching is disabled.
|
||||
|
||||
:param list(int) batch: List of integers representing the indices of
|
||||
the data points to be fetched.
|
||||
:return: Dictionary containing the data points fetched from the dataset.
|
||||
:rtype: dict
|
||||
"""
|
||||
# Call the fetch_from_idx_list method of the dataset
|
||||
return self.dataset.fetch_from_idx_list(batch)
|
||||
|
||||
def _collate_standard_dataloader(self, batch):
|
||||
def _collate_torch_dataloader(self, batch):
|
||||
"""
|
||||
Function used to collate the batch
|
||||
"""
|
||||
@@ -112,6 +147,19 @@ class Collator:
|
||||
|
||||
@staticmethod
|
||||
def _collate_tensor_dataset(data_list):
|
||||
"""
|
||||
Function used to collate the data when the dataset is a
|
||||
`PinaTensorDataset`.
|
||||
|
||||
:param data_list: List of `torch.Tensor` or `LabelTensor` to be
|
||||
collated.
|
||||
:type data_list: list(torch.Tensor) | list(LabelTensor)
|
||||
:raises RuntimeError: If the data is not a `torch.Tensor` or a
|
||||
`LabelTensor`.
|
||||
:return: Batch of data
|
||||
:rtype: dict
|
||||
"""
|
||||
|
||||
if isinstance(data_list[0], LabelTensor):
|
||||
return LabelTensor.stack(data_list)
|
||||
if isinstance(data_list[0], torch.Tensor):
|
||||
@@ -119,15 +167,36 @@ class Collator:
|
||||
raise RuntimeError("Data must be Tensors or LabelTensor ")
|
||||
|
||||
def _collate_graph_dataset(self, data_list):
|
||||
"""
|
||||
Function used to collate the data when the dataset is a
|
||||
`PinaGraphDataset`.
|
||||
|
||||
:param data_list: List of `Data` or `Graph` to be collated.
|
||||
:type data_list: list(Data) | list(Graph)
|
||||
:raises RuntimeError: If the data is not a `Data` or a `Graph`.
|
||||
:return: Batch of data
|
||||
:rtype: dict
|
||||
"""
|
||||
|
||||
if isinstance(data_list[0], LabelTensor):
|
||||
return LabelTensor.cat(data_list)
|
||||
if isinstance(data_list[0], torch.Tensor):
|
||||
return torch.cat(data_list)
|
||||
if isinstance(data_list[0], Data):
|
||||
return self.dataset.create_graph_batch(data_list)
|
||||
return self.dataset.create_batch(data_list)
|
||||
raise RuntimeError("Data must be Tensors or LabelTensor or pyG Data")
|
||||
|
||||
def __call__(self, batch):
|
||||
"""
|
||||
Call the function to collate the batch, defined in __init__.
|
||||
|
||||
:param batch: list of indices or list of retrieved data
|
||||
:type batch: list(int) | list(dict)
|
||||
:return: Dictionary containing the data points fetched from the dataset,
|
||||
collated.
|
||||
:rtype: dict
|
||||
"""
|
||||
|
||||
return self.callable_function(batch)
|
||||
|
||||
|
||||
@@ -137,6 +206,16 @@ class PinaSampler:
|
||||
"""
|
||||
|
||||
def __new__(cls, dataset, shuffle):
|
||||
"""
|
||||
Create the sampler instance, according to shuffle and whether the
|
||||
environment is distributed or not.
|
||||
|
||||
:param PinaDataset dataset: The dataset to be sampled.
|
||||
:param bool shuffle: whether to shuffle the dataset or not before
|
||||
sampling.
|
||||
:return: The sampler instance.
|
||||
:rtype: torch.utils.data.Sampler
|
||||
"""
|
||||
|
||||
if (
|
||||
torch.distributed.is_available()
|
||||
@@ -173,29 +252,24 @@ class PinaDataModule(LightningDataModule):
|
||||
"""
|
||||
Initialize the object, creating datasets based on the input problem.
|
||||
|
||||
:param problem: The problem defining the dataset.
|
||||
:type problem: AbstractProblem
|
||||
:param train_size: Fraction or number of elements in the training split.
|
||||
:type train_size: float
|
||||
:param test_size: Fraction or number of elements in the test split.
|
||||
:type test_size: float
|
||||
:param val_size: Fraction or number of elements in the validation split.
|
||||
:type val_size: float
|
||||
:param batch_size: Batch size used for training. If None, the entire
|
||||
dataset is used per batch.
|
||||
:type batch_size: int or None
|
||||
:param shuffle: Whether to shuffle the dataset before splitting.
|
||||
:type shuffle: bool
|
||||
:param repeat: Whether to repeat the dataset indefinitely.
|
||||
:type repeat: bool
|
||||
:param AbstractProblem problem: The problem containing the data on which
|
||||
to train/test the model.
|
||||
:param float train_size: Fraction or number of elements in the training
|
||||
split.
|
||||
:param float test_size: Fraction or number of elements in the test
|
||||
split.
|
||||
:param float val_size: Fraction or number of elements in the validation
|
||||
split.
|
||||
:param batch_size: The batch size used for training. If `None`, the
|
||||
entire dataset is used per batch.
|
||||
:type batch_size: int | None
|
||||
:param bool shuffle: Whether to shuffle the dataset before splitting.
|
||||
:param bool repeat: Whether to repeat the dataset indefinitely.
|
||||
:param automatic_batching: Whether to enable automatic batching.
|
||||
:type automatic_batching: bool
|
||||
:param num_workers: Number of worker threads for data loading.
|
||||
:param int num_workers: Number of worker threads for data loading.
|
||||
Default 0 (serial loading)
|
||||
:type num_workers: int
|
||||
:param pin_memory: Whether to use pinned memory for faster data
|
||||
:param bool pin_memory: Whether to use pinned memory for faster data
|
||||
transfer to GPU. (Default False)
|
||||
:type pin_memory: bool
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@@ -365,10 +439,14 @@ class PinaDataModule(LightningDataModule):
|
||||
sampler = PinaSampler(dataset, shuffle)
|
||||
if self.automatic_batching:
|
||||
collate = Collator(
|
||||
self.find_max_conditions_lengths(split), dataset=dataset
|
||||
self.find_max_conditions_lengths(split),
|
||||
self.automatic_batching,
|
||||
dataset=dataset,
|
||||
)
|
||||
else:
|
||||
collate = Collator(None, dataset=dataset)
|
||||
collate = Collator(
|
||||
None, self.automatic_batching, dataset=dataset
|
||||
)
|
||||
return DataLoader(
|
||||
dataset,
|
||||
self.batch_size,
|
||||
@@ -413,23 +491,51 @@ class PinaDataModule(LightningDataModule):
|
||||
def train_dataloader(self):
|
||||
"""
|
||||
Create the training dataloader
|
||||
|
||||
:return: The training dataloader
|
||||
:rtype: DataLoader
|
||||
"""
|
||||
return self._create_dataloader("train", self.train_dataset)
|
||||
|
||||
def test_dataloader(self):
|
||||
"""
|
||||
Create the testing dataloader
|
||||
|
||||
:return: The testing dataloader
|
||||
:rtype: DataLoader
|
||||
"""
|
||||
return self._create_dataloader("test", self.test_dataset)
|
||||
|
||||
@staticmethod
|
||||
def _transfer_batch_to_device_dummy(batch, device, dataloader_idx):
|
||||
"""
|
||||
Transfer the batch to the device. This method is called in the
|
||||
training loop and is used to transfer the batch to the device.
|
||||
This method is used when the batch size is None: batch has already
|
||||
been transferred to the device.
|
||||
|
||||
:param list(tuple) batch: list of tuples where the first element of the
|
||||
tuple is the condition name and the second element is the data.
|
||||
:param device: device to which the batch is transferred
|
||||
:type device: torch.device
|
||||
:param dataloader_idx: index of the dataloader
|
||||
:type dataloader_idx: int
|
||||
:return: The batch transferred to the device.
|
||||
:rtype: list(tuple)
|
||||
"""
|
||||
return batch
|
||||
|
||||
def _transfer_batch_to_device(self, batch, device, dataloader_idx):
|
||||
"""
|
||||
Transfer the batch to the device. This method is called in the
|
||||
training loop and is used to transfer the batch to the device.
|
||||
|
||||
:param dict batch: The batch to be transferred to the device.
|
||||
:param device: The device to which the batch is transferred.
|
||||
:type device: torch.device
|
||||
:param int dataloader_idx: The index of the dataloader.
|
||||
:return: The batch transferred to the device.
|
||||
:rtype: list(tuple)
|
||||
"""
|
||||
batch = [
|
||||
(
|
||||
@@ -456,7 +562,10 @@ class PinaDataModule(LightningDataModule):
|
||||
@property
|
||||
def input(self):
|
||||
"""
|
||||
# TODO
|
||||
Return all the input points coming from all the datasets.
|
||||
|
||||
:return: The input points for training.
|
||||
:rtype: dict
|
||||
"""
|
||||
to_return = {}
|
||||
if hasattr(self, "train_dataset") and self.train_dataset is not None:
|
||||
@@ -464,5 +573,5 @@ class PinaDataModule(LightningDataModule):
|
||||
if hasattr(self, "val_dataset") and self.val_dataset is not None:
|
||||
to_return["val"] = self.val_dataset.input
|
||||
if hasattr(self, "test_dataset") and self.test_dataset is not None:
|
||||
to_return = self.test_dataset.input
|
||||
to_return["test"] = self.test_dataset.input
|
||||
return to_return
|
||||
|
||||
@@ -10,13 +10,31 @@ from ..graph import Graph, LabelBatch
|
||||
|
||||
class PinaDatasetFactory:
|
||||
"""
|
||||
Factory class for the PINA dataset. Depending on the type inside the
|
||||
conditions it creates a different dataset object:
|
||||
- PinaTensorDataset for torch.Tensor
|
||||
- PinaGraphDataset for list of torch_geometric.data.Data objects
|
||||
Factory class for the PINA dataset.
|
||||
|
||||
Depending on the type inside the conditions, it creates a different dataset
|
||||
object:
|
||||
|
||||
- :class:`PinaTensorDataset` for `torch.Tensor`
|
||||
- :class:`PinaGraphDataset` for `list` of `torch_geometric.data.Data`
|
||||
objects
|
||||
"""
|
||||
|
||||
def __new__(cls, conditions_dict, **kwargs):
|
||||
"""
|
||||
Instantiate the appropriate subclass of :class:`PinaDataset`.
|
||||
|
||||
If a graph is present in the conditions, returns a
|
||||
:class:`PinaGraphDataset`, otherwise returns a
|
||||
:class:`PinaTensorDataset`.
|
||||
|
||||
:param dict conditions_dict: Dictionary containing the conditions.
|
||||
:return: A subclass of :class:`PinaDataset`.
|
||||
:rtype: :class:`PinaTensorDataset` | :class:`PinaGraphDataset`
|
||||
|
||||
:raises ValueError: If an empty dictionary is provided.
|
||||
"""
|
||||
|
||||
# Check if conditions_dict is empty
|
||||
if len(conditions_dict) == 0:
|
||||
raise ValueError("No conditions provided")
|
||||
@@ -31,9 +49,21 @@ class PinaDatasetFactory:
|
||||
|
||||
@staticmethod
|
||||
def _is_graph_dataset(conditions_dict):
|
||||
"""
|
||||
Check if a graph is present in the conditions.
|
||||
|
||||
:param conditions_dict: Dictionary containing the conditions.
|
||||
:type conditions_dict: dict
|
||||
:return: True if a graph is present in the conditions, False otherwise
|
||||
:rtype: bool
|
||||
"""
|
||||
|
||||
# Iterate over the conditions dictionary
|
||||
for v in conditions_dict.values():
|
||||
# Iterate over the values of the current condition
|
||||
for cond in v.values():
|
||||
if isinstance(cond, (Data, Graph, list)):
|
||||
# Check if the current value is a list of Data objects
|
||||
if isinstance(cond, (Data, Graph, list, tuple)):
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -46,6 +76,19 @@ class PinaDataset(Dataset):
|
||||
def __init__(
|
||||
self, conditions_dict, max_conditions_lengths, automatic_batching
|
||||
):
|
||||
"""
|
||||
Initialize the :class:`PinaDataset`.
|
||||
|
||||
Stores the conditions dictionary, the maximum number of conditions to
|
||||
consider, and the automatic batching flag.
|
||||
|
||||
:param dict conditions_dict: Dictionary containing the conditions.
|
||||
:param dict max_conditions_lengths: Maximum number of data points to
|
||||
consider in a single batch for each condition.
|
||||
:param bool automatic_batching: Whether PyTorch automatic batching is
|
||||
enabled in :class:`PinaDataModule`.
|
||||
"""
|
||||
|
||||
# Store the conditions dictionary
|
||||
self.conditions_dict = conditions_dict
|
||||
# Store the maximum number of conditions to consider
|
||||
@@ -63,7 +106,13 @@ class PinaDataset(Dataset):
|
||||
self._getitem_func = self._getitem_dummy
|
||||
|
||||
def _get_max_len(self):
|
||||
""""""
|
||||
"""
|
||||
Returns the length of the longest condition in the dataset
|
||||
|
||||
:return: Length of the longest condition in the dataset
|
||||
:rtype: int
|
||||
"""
|
||||
|
||||
max_len = 0
|
||||
for condition in self.conditions_dict.values():
|
||||
max_len = max(max_len, len(condition["input"]))
|
||||
@@ -76,10 +125,29 @@ class PinaDataset(Dataset):
|
||||
return self._getitem_func(idx)
|
||||
|
||||
def _getitem_dummy(self, idx):
|
||||
"""
|
||||
Return the index itself. This is used when automatic batching is
|
||||
disabled to postpone the data retrieval to the dataloader.
|
||||
|
||||
:param idx: Index
|
||||
:type idx: int
|
||||
:return: Index
|
||||
:rtype: int
|
||||
"""
|
||||
|
||||
# If automatic batching is disabled, return the index itself (the data
# retrieval is postponed to the dataloader)
|
||||
return idx
|
||||
|
||||
def _getitem_int(self, idx):
|
||||
"""
|
||||
Return the data at the given index in the dataset. This is used when
|
||||
automatic batching is enabled.
|
||||
|
||||
:param int idx: Index
|
||||
:return: A dictionary containing the data at the given index
|
||||
:rtype: dict
|
||||
"""
|
||||
|
||||
# If automatic batching is enabled, return the data at the given index
|
||||
return {
|
||||
k: {k_data: v[k_data][idx % len(v["input"])] for k_data in v.keys()}
|
||||
@@ -121,7 +189,14 @@ class PinaDataset(Dataset):
|
||||
|
||||
@abstractmethod
|
||||
def _retrive_data(self, data, idx_list):
|
||||
pass
|
||||
"""
|
||||
Retrieve data from the dataset given a list of indices
|
||||
|
||||
:param dict data: Dictionary containing the data
|
||||
:param list idx_list: List of indices to retrieve
|
||||
:return: Dictionary containing the data at the given indices
|
||||
:rtype: dict
|
||||
"""
|
||||
|
||||
|
||||
class PinaTensorDataset(PinaDataset):
|
||||
@@ -131,12 +206,26 @@ class PinaTensorDataset(PinaDataset):
|
||||
|
||||
# Override _retrive_data method for torch.Tensor data
|
||||
def _retrive_data(self, data, idx_list):
|
||||
"""
|
||||
Retrieve data from the dataset given a list of indices
|
||||
|
||||
:param data: Dictionary containing the data
|
||||
(only torch.Tensor/LabelTensor)
|
||||
:type data: dict
|
||||
:param list(int) idx_list: indices to retrieve
|
||||
:return: Dictionary containing the data at the given indices
|
||||
:rtype: dict
|
||||
"""
|
||||
|
||||
return {k: v[idx_list] for k, v in data.items()}
|
||||
|
||||
@property
|
||||
def input(self):
|
||||
"""
|
||||
Method to return input points for training.
|
||||
Method to return all input points from the dataset.
|
||||
|
||||
:return: Dictionary containing the input points
|
||||
:rtype: dict
|
||||
"""
|
||||
return {k: v["input"] for k, v in self.conditions_dict.items()}
|
||||
|
||||
@@ -146,15 +235,33 @@ class PinaGraphDataset(PinaDataset):
|
||||
Class for the PINA dataset with torch_geometric.data.Data data
|
||||
"""
|
||||
|
||||
def _create_graph_batch_from_list(self, data):
|
||||
def _create_graph_batch(self, data):
|
||||
"""
|
||||
Create a LabelBatch object from a list of Data objects.
|
||||
|
||||
:param data: List of Data or Graph objects
|
||||
:type data: list(Data) | list(Graph)
|
||||
:return: LabelBatch object with all the graphs collated in a single batch of
|
||||
disconnected graphs.
|
||||
:rtype: LabelBatch
|
||||
"""
|
||||
batch = LabelBatch.from_data_list(data)
|
||||
return batch
|
||||
|
||||
def _create_output_batch(self, data):
|
||||
def _create_tensor_batch(self, data):
|
||||
"""
|
||||
Create a torch.Tensor object from a list of torch.Tensor objects.
|
||||
|
||||
:param data: torch.Tensor object of shape (N, ...) where N is the
|
||||
number of data points.
|
||||
:type data: torch.Tensor | LabelTensor
|
||||
:return: reshaped torch.Tensor or LabelTensor object
|
||||
:rtype: torch.Tensor | LabelTensor
|
||||
"""
|
||||
out = data.reshape(-1, *data.shape[2:])
|
||||
return out
|
||||
|
||||
def create_graph_batch(self, data):
|
||||
def create_batch(self, data):
|
||||
"""
|
||||
Create a Batch object from a list of Data objects.
|
||||
|
||||
@@ -163,20 +270,29 @@ class PinaGraphDataset(PinaDataset):
|
||||
:return: Batch object
|
||||
:rtype: Batch or PinaBatch
|
||||
"""
|
||||
|
||||
if isinstance(data[0], Data):
|
||||
return self._create_graph_batch_from_list(data)
|
||||
return self._create_output_batch(data)
|
||||
return self._create_graph_batch(data)
|
||||
return self._create_tensor_batch(data)
|
||||
|
||||
# Override _retrive_data method for graph handling
|
||||
def _retrive_data(self, data, idx_list):
|
||||
"""
|
||||
Retrieve data from the dataset given a list of indices
|
||||
|
||||
:param dict data: dictionary containing the data
|
||||
:param list idx_list: list of indices to retrieve
|
||||
:return: dictionary containing the data at the given indices
|
||||
:rtype: dict
|
||||
"""
|
||||
# Return the data from the current condition
|
||||
# If the data is a list of Data objects, create a Batch object
|
||||
# If the data is a list of torch.Tensor objects, create a torch.Tensor
|
||||
return {
|
||||
k: (
|
||||
self._create_graph_batch_from_list([v[i] for i in idx_list])
|
||||
self._create_graph_batch([v[i] for i in idx_list])
|
||||
if isinstance(v, list)
|
||||
else self._create_output_batch(v[idx_list])
|
||||
else self._create_tensor_batch(v[idx_list])
|
||||
)
|
||||
for k, v in data.items()
|
||||
}
|
||||
|
||||
@@ -19,6 +19,9 @@ class Graph(Data):
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Instantiates a new instance of the Graph class, performing type
|
||||
consistency checks.
|
||||
|
||||
:param kwargs: Parameters to construct the Graph object.
|
||||
:return: A new instance of the Graph class.
|
||||
:rtype: Graph
|
||||
@@ -42,7 +45,10 @@ class Graph(Data):
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize the Graph object.
|
||||
Initialize the Graph object by setting the node features, edge index,
|
||||
edge attributes, and positions. The edge index is preprocessed to make
|
||||
the graph undirected if required. For more details, see the
|
||||
:class:`torch_geometric.data.Data` documentation.
|
||||
|
||||
:param x: Optional tensor of node features (N, F) where F is the number
|
||||
of features per node.
|
||||
@@ -69,6 +75,13 @@ class Graph(Data):
|
||||
)
|
||||
|
||||
def _check_type_consistency(self, **kwargs):
|
||||
"""
|
||||
Check the consistency of the types of the input data.
|
||||
|
||||
:param kwargs: Attributes to be checked for consistency.
|
||||
:type kwargs: dict
|
||||
"""
|
||||
|
||||
# default types, specified in cls.__new__, by default they are None
|
||||
# if specified in **kwargs they get overridden
|
||||
x, pos, edge_index, edge_attr = None, None, None, None
|
||||
@@ -92,8 +105,10 @@ class Graph(Data):
|
||||
def _check_pos_consistency(pos):
|
||||
"""
|
||||
Check if the position tensor is consistent.
|
||||
|
||||
:param torch.Tensor pos: The position tensor.
|
||||
"""
|
||||
|
||||
if pos is not None:
|
||||
check_consistency(pos, (torch.Tensor, LabelTensor))
|
||||
if pos.ndim != 2:
|
||||
@@ -103,8 +118,10 @@ class Graph(Data):
|
||||
def _check_edge_index_consistency(edge_index):
|
||||
"""
|
||||
Check if the edge index is consistent.
|
||||
|
||||
:param torch.Tensor edge_index: The edge index tensor.
|
||||
"""
|
||||
|
||||
check_consistency(edge_index, (torch.Tensor, LabelTensor))
|
||||
if edge_index.ndim != 2:
|
||||
raise ValueError("edge_index must be a 2D tensor.")
|
||||
@@ -114,11 +131,13 @@ class Graph(Data):
|
||||
@staticmethod
|
||||
def _check_edge_attr_consistency(edge_attr, edge_index):
|
||||
"""
|
||||
Check if the edge attr is consistent.
|
||||
:param torch.Tensor edge_attr: The edge attribute tensor.
|
||||
Check if the edge attribute tensor is consistent in type and shape
|
||||
with the edge index.
|
||||
|
||||
:param torch.Tensor edge_attr: The edge attribute tensor.
|
||||
:param torch.Tensor edge_index: The edge index tensor.
|
||||
"""
|
||||
|
||||
if edge_attr is not None:
|
||||
check_consistency(edge_attr, (torch.Tensor, LabelTensor))
|
||||
if edge_attr.ndim != 2:
|
||||
@@ -134,10 +153,13 @@ class Graph(Data):
|
||||
@staticmethod
|
||||
def _check_x_consistency(x, pos=None):
|
||||
"""
|
||||
Check if the input tensor x is consistent with the position tensor pos.
|
||||
Check if the input tensor x is consistent with the position tensor
|
||||
`pos`.
|
||||
|
||||
:param torch.Tensor x: The input tensor.
|
||||
:param torch.Tensor pos: The position tensor.
|
||||
"""
|
||||
|
||||
if x is not None:
|
||||
check_consistency(x, (torch.Tensor, LabelTensor))
|
||||
if x.ndim != 2:
|
||||
@@ -152,22 +174,24 @@ class Graph(Data):
|
||||
@staticmethod
|
||||
def _preprocess_edge_index(edge_index, undirected):
|
||||
"""
|
||||
Preprocess the edge index.
|
||||
Preprocess the edge index to make the graph undirected (if required).
|
||||
|
||||
:param torch.Tensor edge_index: The edge index.
|
||||
:param bool undirected: Whether the graph is undirected.
|
||||
:return: The preprocessed edge index.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
|
||||
if undirected:
|
||||
edge_index = to_undirected(edge_index)
|
||||
return edge_index
|
||||
|
||||
def extract(self, labels, attr="x"):
|
||||
"""
|
||||
Perform extraction of labels on node features (x)
|
||||
Perform extraction of labels from the attribute specified by `attr`.
|
||||
|
||||
:param labels: Labels to extract
|
||||
:type labels: list[str] | tuple[str] | str
|
||||
:type labels: list[str] | tuple[str] | str | dict
|
||||
:return: Batch object with extraction performed on x
|
||||
:rtype: PinaBatch
|
||||
"""
|
||||
@@ -193,21 +217,24 @@ class GraphBuilder:
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a new instance of the Graph class.
|
||||
Compute the edge attributes and create a new instance of the Graph
|
||||
class.
|
||||
|
||||
:param pos: A tensor of shape (N, D) representing the positions of N
|
||||
points in D-dimensional space.
|
||||
:type pos: torch.Tensor | LabelTensor
|
||||
:type pos: torch.Tensor or LabelTensor
|
||||
:param edge_index: A tensor of shape (2, E) representing the indices of
|
||||
the graph's edges.
|
||||
:type edge_index: torch.Tensor
|
||||
:param x: Optional tensor of node features (N, F) where F is the number
|
||||
of features per node.
|
||||
:type x: torch.Tensor, LabelTensor
|
||||
:param bool edge_attr: Optional edge attributes (E, F) where F is the
|
||||
number of features per edge.
|
||||
:param callable custom_edge_func: A custom function to compute edge
|
||||
attributes.
|
||||
:param x: Optional tensor of node features of shape (N, F), where F is
|
||||
the number of features per node.
|
||||
:type x: torch.Tensor | LabelTensor, optional
|
||||
:param edge_attr: Optional tensor of edge attributes of shape (E, F),
|
||||
where F is the number of features per edge.
|
||||
:type edge_attr: torch.Tensor, optional
|
||||
:param custom_edge_func: A custom function to compute edge attributes.
|
||||
If provided, overrides `edge_attr`.
|
||||
:type custom_edge_func: callable, optional
|
||||
:param kwargs: Additional keyword arguments passed to the Graph class
|
||||
constructor.
|
||||
:return: A Graph instance constructed using the provided information.
|
||||
@@ -249,18 +276,18 @@ class RadiusGraph(GraphBuilder):
|
||||
|
||||
def __new__(cls, pos, radius, **kwargs):
|
||||
"""
|
||||
Creates a new instance of the Graph class using a radius-based graph
|
||||
construction.
|
||||
Extends the `GraphBuilder` class to compute edge_index based on a
|
||||
radius. Each point is connected to all the points within the radius.
|
||||
|
||||
:param pos: A tensor of shape (N, D) representing the positions of N
|
||||
points in D-dimensional space.
|
||||
:type pos: torch.Tensor | LabelTensor
|
||||
:param float radius: The radius within which points are connected.
|
||||
:Keyword Arguments:
|
||||
The additional keyword arguments to be passed to GraphBuilder
|
||||
and Graph classes
|
||||
:return: Graph instance containg the information passed in input and
|
||||
the computed edge_index
|
||||
:type pos: torch.Tensor or LabelTensor
|
||||
:param radius: The radius within which points are connected.
|
||||
:type radius: float
|
||||
:param kwargs: Additional keyword arguments to be passed to the
|
||||
`GraphBuilder` and `Graph` constructors.
|
||||
:return: A `Graph` instance containing the input information and the
|
||||
computed edge_index.
|
||||
:rtype: Graph
|
||||
"""
|
||||
edge_index = cls.compute_radius_graph(pos, radius)
|
||||
@@ -269,7 +296,8 @@ class RadiusGraph(GraphBuilder):
|
||||
@staticmethod
|
||||
def compute_radius_graph(points, radius):
|
||||
"""
|
||||
Computes a radius-based graph for a given set of points.
|
||||
Computes edge_index for a given set of points based on the radius.
|
||||
Each point is connected to all the points within the radius.
|
||||
|
||||
:param points: A tensor of shape (N, D) representing the positions of
|
||||
N points in D-dimensional space.
|
||||
@@ -295,7 +323,7 @@ class KNNGraph(GraphBuilder):
|
||||
def __new__(cls, pos, neighbours, **kwargs):
|
||||
"""
|
||||
Creates a new instance of the Graph class using k-nearest neighbors
|
||||
to compute edge_index.
|
||||
algorithm to define the edges.
|
||||
|
||||
:param pos: A tensor of shape (N, D) representing the positions of N
|
||||
points in D-dimensional space.
|
||||
@@ -323,8 +351,9 @@ class KNNGraph(GraphBuilder):
|
||||
N points in D-dimensional space.
|
||||
:type points: torch.Tensor | LabelTensor
|
||||
:param int k: The number of nearest neighbors to find for each point.
|
||||
:rtype torch.Tensor: A tensor of shape (2, E), where E is the number of
|
||||
:return: A tensor of shape (2, E), where E is the number of
|
||||
edges, representing the edge indices of the KNN graph.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
|
||||
dist = torch.cdist(points, points, p=2)
|
||||
@@ -343,6 +372,11 @@ class LabelBatch(Batch):
|
||||
def from_data_list(cls, data_list):
|
||||
"""
|
||||
Create a Batch object from a list of Data objects.
|
||||
|
||||
:param data_list: List of Data/Graph objects
|
||||
:type data_list: list[Data] | list[Graph]
|
||||
:return: A Batch object containing the data in the list
|
||||
:rtype: Batch
|
||||
"""
|
||||
# Store the labels of Data/Graph objects (all data have the same labels)
|
||||
# If the data do not contain labels, labels is an empty dictionary,
|
||||
|
||||
Reference in New Issue
Block a user