Documentation and docstring graph and data
This commit is contained in:
committed by
Nicola Demo
parent
6ce0bafc2b
commit
635e3b3a75
@@ -23,16 +23,24 @@ class DummyDataloader:
|
|||||||
|
|
||||||
def __init__(self, dataset):
|
def __init__(self, dataset):
|
||||||
"""
|
"""
|
||||||
param dataset: The dataset object to be processed.
|
Preprare a dataloader object which will return the entire dataset
|
||||||
:notes:
|
in a single batch. Depending on the number of GPUs, the dataset we
|
||||||
- **Distributed Environment**:
|
have the following cases:
|
||||||
- Divides the dataset across processes using the
|
|
||||||
rank and world size.
|
- **Distributed Environment** (multiple GPUs):
|
||||||
- Fetches only the portion of data corresponding to
|
- Divides the dataset across processes using the rank and world
|
||||||
the current process.
|
size.
|
||||||
- **Non-Distributed Environment**:
|
- Fetches only the portion of data corresponding to the current
|
||||||
- Fetches the entire dataset.
|
process.
|
||||||
|
- **Non-Distributed Environment** (single GPU):
|
||||||
|
- Fetches the entire dataset.
|
||||||
|
|
||||||
|
:param dataset: The dataset object to be processed.
|
||||||
|
:type dataset: PinaDataset
|
||||||
|
|
||||||
|
.. note:: This data loader is used when the batch size is None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if (
|
if (
|
||||||
torch.distributed.is_available()
|
torch.distributed.is_available()
|
||||||
and torch.distributed.is_initialized()
|
and torch.distributed.is_initialized()
|
||||||
@@ -67,23 +75,50 @@ class Collator:
|
|||||||
Class used to collate the batch
|
Class used to collate the batch
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, max_conditions_lengths, dataset=None):
|
def __init__(
|
||||||
|
self, max_conditions_lengths, automatic_batching, dataset=None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the object, setting the collate function based on whether
|
||||||
|
automatic batching is enabled or not.
|
||||||
|
|
||||||
|
:param dict max_conditions_lengths: dict containing the maximum number
|
||||||
|
of data points to consider in a single batch for each condition.
|
||||||
|
:param PinaDataset dataset: The dataset where the data is stored.
|
||||||
|
"""
|
||||||
|
|
||||||
self.max_conditions_lengths = max_conditions_lengths
|
self.max_conditions_lengths = max_conditions_lengths
|
||||||
|
# Set the collate function based on the batching strategy
|
||||||
|
# collate_pina_dataloader is used when automatic batching is disabled
|
||||||
|
# collate_torch_dataloader is used when automatic batching is enabled
|
||||||
self.callable_function = (
|
self.callable_function = (
|
||||||
self._collate_custom_dataloader
|
self._collate_torch_dataloader
|
||||||
if max_conditions_lengths is None
|
if automatic_batching
|
||||||
else (self._collate_standard_dataloader)
|
else (self._collate_pina_dataloader)
|
||||||
)
|
)
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
|
|
||||||
|
# Set the function which performs the actual collation
|
||||||
if isinstance(self.dataset, PinaTensorDataset):
|
if isinstance(self.dataset, PinaTensorDataset):
|
||||||
|
# If the dataset is a PinaTensorDataset, use this collate function
|
||||||
self._collate = self._collate_tensor_dataset
|
self._collate = self._collate_tensor_dataset
|
||||||
else:
|
else:
|
||||||
|
# If the dataset is a PinaDataset, use this collate function
|
||||||
self._collate = self._collate_graph_dataset
|
self._collate = self._collate_graph_dataset
|
||||||
|
|
||||||
def _collate_custom_dataloader(self, batch):
|
def _collate_pina_dataloader(self, batch):
|
||||||
|
"""
|
||||||
|
Function used to create a batch when automatic batching is disabled.
|
||||||
|
|
||||||
|
:param list(int) batch: List of integers representing the indices of
|
||||||
|
the data points to be fetched.
|
||||||
|
:return: Dictionary containing the data points fetched from the dataset.
|
||||||
|
:rtype: dict
|
||||||
|
"""
|
||||||
|
# Call the fetch_from_idx_list method of the dataset
|
||||||
return self.dataset.fetch_from_idx_list(batch)
|
return self.dataset.fetch_from_idx_list(batch)
|
||||||
|
|
||||||
def _collate_standard_dataloader(self, batch):
|
def _collate_torch_dataloader(self, batch):
|
||||||
"""
|
"""
|
||||||
Function used to collate the batch
|
Function used to collate the batch
|
||||||
"""
|
"""
|
||||||
@@ -112,6 +147,19 @@ class Collator:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _collate_tensor_dataset(data_list):
|
def _collate_tensor_dataset(data_list):
|
||||||
|
"""
|
||||||
|
Function used to collate the data when the dataset is a
|
||||||
|
`PinaTensorDataset`.
|
||||||
|
|
||||||
|
:param data_list: List of `torch.Tensor` or `LabelTensor` to be
|
||||||
|
collated.
|
||||||
|
:type data_list: list(torch.Tensor) | list(LabelTensor)
|
||||||
|
:raises RuntimeError: If the data is not a `torch.Tensor` or a
|
||||||
|
`LabelTensor`.
|
||||||
|
:return: Batch of data
|
||||||
|
:rtype: dict
|
||||||
|
"""
|
||||||
|
|
||||||
if isinstance(data_list[0], LabelTensor):
|
if isinstance(data_list[0], LabelTensor):
|
||||||
return LabelTensor.stack(data_list)
|
return LabelTensor.stack(data_list)
|
||||||
if isinstance(data_list[0], torch.Tensor):
|
if isinstance(data_list[0], torch.Tensor):
|
||||||
@@ -119,15 +167,36 @@ class Collator:
|
|||||||
raise RuntimeError("Data must be Tensors or LabelTensor ")
|
raise RuntimeError("Data must be Tensors or LabelTensor ")
|
||||||
|
|
||||||
def _collate_graph_dataset(self, data_list):
|
def _collate_graph_dataset(self, data_list):
|
||||||
|
"""
|
||||||
|
Function used to collate the data when the dataset is a
|
||||||
|
`PinaGraphDataset`.
|
||||||
|
|
||||||
|
:param data_list: List of `Data` or `Graph` to be collated.
|
||||||
|
:type data_list: list(Data) | list(Graph)
|
||||||
|
:raises RuntimeError: If the data is not a `Data` or a `Graph`.
|
||||||
|
:return: Batch of data
|
||||||
|
:rtype: dict
|
||||||
|
"""
|
||||||
|
|
||||||
if isinstance(data_list[0], LabelTensor):
|
if isinstance(data_list[0], LabelTensor):
|
||||||
return LabelTensor.cat(data_list)
|
return LabelTensor.cat(data_list)
|
||||||
if isinstance(data_list[0], torch.Tensor):
|
if isinstance(data_list[0], torch.Tensor):
|
||||||
return torch.cat(data_list)
|
return torch.cat(data_list)
|
||||||
if isinstance(data_list[0], Data):
|
if isinstance(data_list[0], Data):
|
||||||
return self.dataset.create_graph_batch(data_list)
|
return self.dataset.create_batch(data_list)
|
||||||
raise RuntimeError("Data must be Tensors or LabelTensor or pyG Data")
|
raise RuntimeError("Data must be Tensors or LabelTensor or pyG Data")
|
||||||
|
|
||||||
def __call__(self, batch):
|
def __call__(self, batch):
|
||||||
|
"""
|
||||||
|
Call the function to collate the batch, defined in __init__.
|
||||||
|
|
||||||
|
:param batch: list of indices or list of retrieved data
|
||||||
|
:type batch: list(int) | list(dict)
|
||||||
|
:return: Dictionary containing the data points fetched from the dataset,
|
||||||
|
collated.
|
||||||
|
:rtype: dict
|
||||||
|
"""
|
||||||
|
|
||||||
return self.callable_function(batch)
|
return self.callable_function(batch)
|
||||||
|
|
||||||
|
|
||||||
@@ -137,6 +206,16 @@ class PinaSampler:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(cls, dataset, shuffle):
|
def __new__(cls, dataset, shuffle):
|
||||||
|
"""
|
||||||
|
Create the sampler instance, according to shuffle and whether the
|
||||||
|
environment is distributed or not.
|
||||||
|
|
||||||
|
:param PinaDataset dataset: The dataset to be sampled.
|
||||||
|
:param bool shuffle: whether to shuffle the dataset or not before
|
||||||
|
sampling.
|
||||||
|
:return: The sampler instance.
|
||||||
|
:rtype: torch.utils.data.Sampler
|
||||||
|
"""
|
||||||
|
|
||||||
if (
|
if (
|
||||||
torch.distributed.is_available()
|
torch.distributed.is_available()
|
||||||
@@ -173,29 +252,24 @@ class PinaDataModule(LightningDataModule):
|
|||||||
"""
|
"""
|
||||||
Initialize the object, creating datasets based on the input problem.
|
Initialize the object, creating datasets based on the input problem.
|
||||||
|
|
||||||
:param problem: The problem defining the dataset.
|
:param AbstractProblem problem: The problem containing the data on which
|
||||||
:type problem: AbstractProblem
|
to train/test the model.
|
||||||
:param train_size: Fraction or number of elements in the training split.
|
:param float train_size: Fraction or number of elements in the training
|
||||||
:type train_size: float
|
split.
|
||||||
:param test_size: Fraction or number of elements in the test split.
|
:param float test_size: Fraction or number of elements in the test
|
||||||
:type test_size: float
|
split.
|
||||||
:param val_size: Fraction or number of elements in the validation split.
|
:param float val_size: Fraction or number of elements in the validation
|
||||||
:type val_size: float
|
split.
|
||||||
:param batch_size: Batch size used for training. If None, the entire
|
:param batch_size: The batch size used for training. If `None`, the
|
||||||
dataset is used per batch.
|
entire dataset is used per batch.
|
||||||
:type batch_size: int or None
|
:type batch_size: int | None
|
||||||
:param shuffle: Whether to shuffle the dataset before splitting.
|
:param bool shuffle: Whether to shuffle the dataset before splitting.
|
||||||
:type shuffle: bool
|
:param bool repeat: Whether to repeat the dataset indefinitely.
|
||||||
:param repeat: Whether to repeat the dataset indefinitely.
|
|
||||||
:type repeat: bool
|
|
||||||
:param automatic_batching: Whether to enable automatic batching.
|
:param automatic_batching: Whether to enable automatic batching.
|
||||||
:type automatic_batching: bool
|
:param int num_workers: Number of worker threads for data loading.
|
||||||
:param num_workers: Number of worker threads for data loading.
|
|
||||||
Default 0 (serial loading)
|
Default 0 (serial loading)
|
||||||
:type num_workers: int
|
:param bool pin_memory: Whether to use pinned memory for faster data
|
||||||
:param pin_memory: Whether to use pinned memory for faster data
|
|
||||||
transfer to GPU. (Default False)
|
transfer to GPU. (Default False)
|
||||||
:type pin_memory: bool
|
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -365,10 +439,14 @@ class PinaDataModule(LightningDataModule):
|
|||||||
sampler = PinaSampler(dataset, shuffle)
|
sampler = PinaSampler(dataset, shuffle)
|
||||||
if self.automatic_batching:
|
if self.automatic_batching:
|
||||||
collate = Collator(
|
collate = Collator(
|
||||||
self.find_max_conditions_lengths(split), dataset=dataset
|
self.find_max_conditions_lengths(split),
|
||||||
|
self.automatic_batching,
|
||||||
|
dataset=dataset,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
collate = Collator(None, dataset=dataset)
|
collate = Collator(
|
||||||
|
None, self.automatic_batching, dataset=dataset
|
||||||
|
)
|
||||||
return DataLoader(
|
return DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
self.batch_size,
|
self.batch_size,
|
||||||
@@ -413,23 +491,51 @@ class PinaDataModule(LightningDataModule):
|
|||||||
def train_dataloader(self):
|
def train_dataloader(self):
|
||||||
"""
|
"""
|
||||||
Create the training dataloader
|
Create the training dataloader
|
||||||
|
|
||||||
|
:return: The training dataloader
|
||||||
|
:rtype: DataLoader
|
||||||
"""
|
"""
|
||||||
return self._create_dataloader("train", self.train_dataset)
|
return self._create_dataloader("train", self.train_dataset)
|
||||||
|
|
||||||
def test_dataloader(self):
|
def test_dataloader(self):
|
||||||
"""
|
"""
|
||||||
Create the testing dataloader
|
Create the testing dataloader
|
||||||
|
|
||||||
|
:return: The testing dataloader
|
||||||
|
:rtype: DataLoader
|
||||||
"""
|
"""
|
||||||
return self._create_dataloader("test", self.test_dataset)
|
return self._create_dataloader("test", self.test_dataset)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _transfer_batch_to_device_dummy(batch, device, dataloader_idx):
|
def _transfer_batch_to_device_dummy(batch, device, dataloader_idx):
|
||||||
|
"""
|
||||||
|
Transfer the batch to the device. This method is called in the
|
||||||
|
training loop and is used to transfer the batch to the device.
|
||||||
|
This method is used when the batch size is None: batch has already
|
||||||
|
been transferred to the device.
|
||||||
|
|
||||||
|
:param list(tuple) batch: list of tuple where the first element of the
|
||||||
|
tuple is the condition name and the second element is the data.
|
||||||
|
:param device: device to which the batch is transferred
|
||||||
|
:type device: torch.device
|
||||||
|
:param dataloader_idx: index of the dataloader
|
||||||
|
:type dataloader_idx: int
|
||||||
|
:return: The batch transferred to the device.
|
||||||
|
:rtype: list(tuple)
|
||||||
|
"""
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def _transfer_batch_to_device(self, batch, device, dataloader_idx):
|
def _transfer_batch_to_device(self, batch, device, dataloader_idx):
|
||||||
"""
|
"""
|
||||||
Transfer the batch to the device. This method is called in the
|
Transfer the batch to the device. This method is called in the
|
||||||
training loop and is used to transfer the batch to the device.
|
training loop and is used to transfer the batch to the device.
|
||||||
|
|
||||||
|
:param dict batch: The batch to be transferred to the device.
|
||||||
|
:param device: The device to which the batch is transferred.
|
||||||
|
:type device: torch.device
|
||||||
|
:param int dataloader_idx: The index of the dataloader.
|
||||||
|
:return: The batch transferred to the device.
|
||||||
|
:rtype: list(tuple)
|
||||||
"""
|
"""
|
||||||
batch = [
|
batch = [
|
||||||
(
|
(
|
||||||
@@ -456,7 +562,10 @@ class PinaDataModule(LightningDataModule):
|
|||||||
@property
|
@property
|
||||||
def input(self):
|
def input(self):
|
||||||
"""
|
"""
|
||||||
# TODO
|
Return all the input points coming from all the datasets.
|
||||||
|
|
||||||
|
:return: The input points for training.
|
||||||
|
:rtype dict
|
||||||
"""
|
"""
|
||||||
to_return = {}
|
to_return = {}
|
||||||
if hasattr(self, "train_dataset") and self.train_dataset is not None:
|
if hasattr(self, "train_dataset") and self.train_dataset is not None:
|
||||||
@@ -464,5 +573,5 @@ class PinaDataModule(LightningDataModule):
|
|||||||
if hasattr(self, "val_dataset") and self.val_dataset is not None:
|
if hasattr(self, "val_dataset") and self.val_dataset is not None:
|
||||||
to_return["val"] = self.val_dataset.input
|
to_return["val"] = self.val_dataset.input
|
||||||
if hasattr(self, "test_dataset") and self.test_dataset is not None:
|
if hasattr(self, "test_dataset") and self.test_dataset is not None:
|
||||||
to_return = self.test_dataset.input
|
to_return["test"] = self.test_dataset.input
|
||||||
return to_return
|
return to_return
|
||||||
|
|||||||
@@ -10,13 +10,31 @@ from ..graph import Graph, LabelBatch
|
|||||||
|
|
||||||
class PinaDatasetFactory:
|
class PinaDatasetFactory:
|
||||||
"""
|
"""
|
||||||
Factory class for the PINA dataset. Depending on the type inside the
|
Factory class for the PINA dataset.
|
||||||
conditions it creates a different dataset object:
|
|
||||||
- PinaTensorDataset for torch.Tensor
|
Depending on the type inside the conditions, it creates a different dataset
|
||||||
- PinaGraphDataset for list of torch_geometric.data.Data objects
|
object:
|
||||||
|
|
||||||
|
- :class:`PinaTensorDataset` for `torch.Tensor`
|
||||||
|
- :class:`PinaGraphDataset` for `list` of `torch_geometric.data.Data`
|
||||||
|
objects
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(cls, conditions_dict, **kwargs):
|
def __new__(cls, conditions_dict, **kwargs):
|
||||||
|
"""
|
||||||
|
Instantiate the appropriate subclass of :class:`PinaDataset`.
|
||||||
|
|
||||||
|
If a graph is present in the conditions, returns a
|
||||||
|
:class:`PinaGraphDataset`, otherwise returns a
|
||||||
|
:class:`PinaTensorDataset`.
|
||||||
|
|
||||||
|
:param dict conditions_dict: Dictionary containing the conditions.
|
||||||
|
:return: A subclass of :class:`PinaDataset`.
|
||||||
|
:rtype: :class:`PinaTensorDataset` | :class:`PinaGraphDataset`
|
||||||
|
|
||||||
|
:raises ValueError: If an empty dictionary is provided.
|
||||||
|
"""
|
||||||
|
|
||||||
# Check if conditions_dict is empty
|
# Check if conditions_dict is empty
|
||||||
if len(conditions_dict) == 0:
|
if len(conditions_dict) == 0:
|
||||||
raise ValueError("No conditions provided")
|
raise ValueError("No conditions provided")
|
||||||
@@ -31,9 +49,21 @@ class PinaDatasetFactory:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _is_graph_dataset(conditions_dict):
|
def _is_graph_dataset(conditions_dict):
|
||||||
|
"""
|
||||||
|
Check if a graph is present in the conditions.
|
||||||
|
|
||||||
|
:param conditions_dict: Dictionary containing the conditions.
|
||||||
|
:type conditions_dict: dict
|
||||||
|
:return: True if a graph is present in the conditions, False otherwise
|
||||||
|
:rtype: bool
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Iterate over the conditions dictionary
|
||||||
for v in conditions_dict.values():
|
for v in conditions_dict.values():
|
||||||
|
# Iterate over the values of the current condition
|
||||||
for cond in v.values():
|
for cond in v.values():
|
||||||
if isinstance(cond, (Data, Graph, list)):
|
# Check if the current value is a list of Data objects
|
||||||
|
if isinstance(cond, (Data, Graph, list, tuple)):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -46,6 +76,19 @@ class PinaDataset(Dataset):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self, conditions_dict, max_conditions_lengths, automatic_batching
|
self, conditions_dict, max_conditions_lengths, automatic_batching
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Initialize the :class:`PinaDataset`.
|
||||||
|
|
||||||
|
Stores the conditions dictionary, the maximum number of conditions to
|
||||||
|
consider, and the automatic batching flag.
|
||||||
|
|
||||||
|
:param dict conditions_dict: Dictionary containing the conditions.
|
||||||
|
:param dict max_conditions_lengths: Maximum number of data points to
|
||||||
|
consider in a single batch for each condition.
|
||||||
|
:param bool automatic_batching: Whether PyTorch automatic batching is
|
||||||
|
enabled in :class:`PinaDataModule`.
|
||||||
|
"""
|
||||||
|
|
||||||
# Store the conditions dictionary
|
# Store the conditions dictionary
|
||||||
self.conditions_dict = conditions_dict
|
self.conditions_dict = conditions_dict
|
||||||
# Store the maximum number of conditions to consider
|
# Store the maximum number of conditions to consider
|
||||||
@@ -63,7 +106,13 @@ class PinaDataset(Dataset):
|
|||||||
self._getitem_func = self._getitem_dummy
|
self._getitem_func = self._getitem_dummy
|
||||||
|
|
||||||
def _get_max_len(self):
|
def _get_max_len(self):
|
||||||
""""""
|
"""
|
||||||
|
Returns the length of the longest condition in the dataset
|
||||||
|
|
||||||
|
:return: Length of the longest condition in the dataset
|
||||||
|
:rtype: int
|
||||||
|
"""
|
||||||
|
|
||||||
max_len = 0
|
max_len = 0
|
||||||
for condition in self.conditions_dict.values():
|
for condition in self.conditions_dict.values():
|
||||||
max_len = max(max_len, len(condition["input"]))
|
max_len = max(max_len, len(condition["input"]))
|
||||||
@@ -76,10 +125,29 @@ class PinaDataset(Dataset):
|
|||||||
return self._getitem_func(idx)
|
return self._getitem_func(idx)
|
||||||
|
|
||||||
def _getitem_dummy(self, idx):
|
def _getitem_dummy(self, idx):
|
||||||
|
"""
|
||||||
|
Return the index itself. This is used when automatic batching is
|
||||||
|
disabled to postpone the data retrieval to the dataloader.
|
||||||
|
|
||||||
|
:param idx: Index
|
||||||
|
:type idx: int
|
||||||
|
:return: Index
|
||||||
|
:rtype: int
|
||||||
|
"""
|
||||||
|
|
||||||
# If automatic batching is disabled, return the data at the given index
|
# If automatic batching is disabled, return the data at the given index
|
||||||
return idx
|
return idx
|
||||||
|
|
||||||
def _getitem_int(self, idx):
|
def _getitem_int(self, idx):
|
||||||
|
"""
|
||||||
|
Return the data at the given index in the dataset. This is used when
|
||||||
|
automatic batching is enabled.
|
||||||
|
|
||||||
|
:param int idx: Index
|
||||||
|
:return: A dictionary containing the data at the given index
|
||||||
|
:rtype: dict
|
||||||
|
"""
|
||||||
|
|
||||||
# If automatic batching is enabled, return the data at the given index
|
# If automatic batching is enabled, return the data at the given index
|
||||||
return {
|
return {
|
||||||
k: {k_data: v[k_data][idx % len(v["input"])] for k_data in v.keys()}
|
k: {k_data: v[k_data][idx % len(v["input"])] for k_data in v.keys()}
|
||||||
@@ -121,7 +189,14 @@ class PinaDataset(Dataset):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _retrive_data(self, data, idx_list):
|
def _retrive_data(self, data, idx_list):
|
||||||
pass
|
"""
|
||||||
|
Retrieve data from the dataset given a list of indices
|
||||||
|
|
||||||
|
:param dict data: Dictionary containing the data
|
||||||
|
:param list idx_list: List of indices to retrieve
|
||||||
|
:return: Dictionary containing the data at the given indices
|
||||||
|
:rtype: dict
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class PinaTensorDataset(PinaDataset):
|
class PinaTensorDataset(PinaDataset):
|
||||||
@@ -131,12 +206,26 @@ class PinaTensorDataset(PinaDataset):
|
|||||||
|
|
||||||
# Override _retrive_data method for torch.Tensor data
|
# Override _retrive_data method for torch.Tensor data
|
||||||
def _retrive_data(self, data, idx_list):
|
def _retrive_data(self, data, idx_list):
|
||||||
|
"""
|
||||||
|
Retrieve data from the dataset given a list of indices
|
||||||
|
|
||||||
|
:param data: Dictionary containing the data
|
||||||
|
(only torch.Tensor/LableTensor)
|
||||||
|
:type data: dict
|
||||||
|
:param list(int) idx_list: indices to retrieve
|
||||||
|
:return: Dictionary containing the data at the given indices
|
||||||
|
:rtype: dict
|
||||||
|
"""
|
||||||
|
|
||||||
return {k: v[idx_list] for k, v in data.items()}
|
return {k: v[idx_list] for k, v in data.items()}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input(self):
|
def input(self):
|
||||||
"""
|
"""
|
||||||
Method to return input points for training.
|
Method to return all input points from the dataset.
|
||||||
|
|
||||||
|
:return: Dictionary containing the input points
|
||||||
|
:rtype: dict
|
||||||
"""
|
"""
|
||||||
return {k: v["input"] for k, v in self.conditions_dict.items()}
|
return {k: v["input"] for k, v in self.conditions_dict.items()}
|
||||||
|
|
||||||
@@ -146,15 +235,33 @@ class PinaGraphDataset(PinaDataset):
|
|||||||
Class for the PINA dataset with torch_geometric.data.Data data
|
Class for the PINA dataset with torch_geometric.data.Data data
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _create_graph_batch_from_list(self, data):
|
def _create_graph_batch(self, data):
|
||||||
|
"""
|
||||||
|
Create a LabelBatch object from a list of Data objects.
|
||||||
|
|
||||||
|
:param data: List of Data or Graph objects
|
||||||
|
:type data: list(Data) | list(Graph)
|
||||||
|
:return: LabelBatch object all the graph collated in a single batch
|
||||||
|
disconnected graphs.
|
||||||
|
:rtype: LabelBatch
|
||||||
|
"""
|
||||||
batch = LabelBatch.from_data_list(data)
|
batch = LabelBatch.from_data_list(data)
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def _create_output_batch(self, data):
|
def _create_tensor_batch(self, data):
|
||||||
|
"""
|
||||||
|
Create a torch.Tensor object from a list of torch.Tensor objects.
|
||||||
|
|
||||||
|
:param data: torch.Tensor object of shape (N, ...) where N is the
|
||||||
|
number of data points.
|
||||||
|
:type data: torch.Tensor | LabelTensor
|
||||||
|
:return: reshaped torch.Tensor or LabelTensor object
|
||||||
|
:rtype: torch.Tensor | LabelTensor
|
||||||
|
"""
|
||||||
out = data.reshape(-1, *data.shape[2:])
|
out = data.reshape(-1, *data.shape[2:])
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def create_graph_batch(self, data):
|
def create_batch(self, data):
|
||||||
"""
|
"""
|
||||||
Create a Batch object from a list of Data objects.
|
Create a Batch object from a list of Data objects.
|
||||||
|
|
||||||
@@ -163,20 +270,29 @@ class PinaGraphDataset(PinaDataset):
|
|||||||
:return: Batch object
|
:return: Batch object
|
||||||
:rtype: Batch or PinaBatch
|
:rtype: Batch or PinaBatch
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(data[0], Data):
|
if isinstance(data[0], Data):
|
||||||
return self._create_graph_batch_from_list(data)
|
return self._create_graph_batch(data)
|
||||||
return self._create_output_batch(data)
|
return self._create_tensor_batch(data)
|
||||||
|
|
||||||
# Override _retrive_data method for graph handling
|
# Override _retrive_data method for graph handling
|
||||||
def _retrive_data(self, data, idx_list):
|
def _retrive_data(self, data, idx_list):
|
||||||
|
"""
|
||||||
|
Retrieve data from the dataset given a list of indices
|
||||||
|
|
||||||
|
:param dict data: dictionary containing the data
|
||||||
|
:param list idx_list: list of indices to retrieve
|
||||||
|
:return: dictionary containing the data at the given indices
|
||||||
|
:rtype: dict
|
||||||
|
"""
|
||||||
# Return the data from the current condition
|
# Return the data from the current condition
|
||||||
# If the data is a list of Data objects, create a Batch object
|
# If the data is a list of Data objects, create a Batch object
|
||||||
# If the data is a list of torch.Tensor objects, create a torch.Tensor
|
# If the data is a list of torch.Tensor objects, create a torch.Tensor
|
||||||
return {
|
return {
|
||||||
k: (
|
k: (
|
||||||
self._create_graph_batch_from_list([v[i] for i in idx_list])
|
self._create_graph_batch([v[i] for i in idx_list])
|
||||||
if isinstance(v, list)
|
if isinstance(v, list)
|
||||||
else self._create_output_batch(v[idx_list])
|
else self._create_tensor_batch(v[idx_list])
|
||||||
)
|
)
|
||||||
for k, v in data.items()
|
for k, v in data.items()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,6 +19,9 @@ class Graph(Data):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
Instantiates a new instance of the Graph class, performing type
|
||||||
|
consistency checks.
|
||||||
|
|
||||||
:param kwargs: Parameters to construct the Graph object.
|
:param kwargs: Parameters to construct the Graph object.
|
||||||
:return: A new instance of the Graph class.
|
:return: A new instance of the Graph class.
|
||||||
:rtype: Graph
|
:rtype: Graph
|
||||||
@@ -42,7 +45,10 @@ class Graph(Data):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the Graph object.
|
Initialize the Graph object by setting the node features, edge index,
|
||||||
|
edge attributes, and positions. The edge index is preprocessed to make
|
||||||
|
the graph undirected if required. For more details, see the
|
||||||
|
:meth: `torch_geometric.data.Data`
|
||||||
|
|
||||||
:param x: Optional tensor of node features (N, F) where F is the number
|
:param x: Optional tensor of node features (N, F) where F is the number
|
||||||
of features per node.
|
of features per node.
|
||||||
@@ -69,6 +75,13 @@ class Graph(Data):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _check_type_consistency(self, **kwargs):
|
def _check_type_consistency(self, **kwargs):
|
||||||
|
"""
|
||||||
|
Check the consistency of the types of the input data.
|
||||||
|
|
||||||
|
:param kwargs: Attributes to be checked for consistency.
|
||||||
|
:type kwargs: dict
|
||||||
|
"""
|
||||||
|
|
||||||
# default types, specified in cls.__new__, by default they are Nont
|
# default types, specified in cls.__new__, by default they are Nont
|
||||||
# if specified in **kwargs they get override
|
# if specified in **kwargs they get override
|
||||||
x, pos, edge_index, edge_attr = None, None, None, None
|
x, pos, edge_index, edge_attr = None, None, None, None
|
||||||
@@ -92,8 +105,10 @@ class Graph(Data):
|
|||||||
def _check_pos_consistency(pos):
|
def _check_pos_consistency(pos):
|
||||||
"""
|
"""
|
||||||
Check if the position tensor is consistent.
|
Check if the position tensor is consistent.
|
||||||
|
|
||||||
:param torch.Tensor pos: The position tensor.
|
:param torch.Tensor pos: The position tensor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if pos is not None:
|
if pos is not None:
|
||||||
check_consistency(pos, (torch.Tensor, LabelTensor))
|
check_consistency(pos, (torch.Tensor, LabelTensor))
|
||||||
if pos.ndim != 2:
|
if pos.ndim != 2:
|
||||||
@@ -103,8 +118,10 @@ class Graph(Data):
|
|||||||
def _check_edge_index_consistency(edge_index):
|
def _check_edge_index_consistency(edge_index):
|
||||||
"""
|
"""
|
||||||
Check if the edge index is consistent.
|
Check if the edge index is consistent.
|
||||||
|
|
||||||
:param torch.Tensor edge_index: The edge index tensor.
|
:param torch.Tensor edge_index: The edge index tensor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
check_consistency(edge_index, (torch.Tensor, LabelTensor))
|
check_consistency(edge_index, (torch.Tensor, LabelTensor))
|
||||||
if edge_index.ndim != 2:
|
if edge_index.ndim != 2:
|
||||||
raise ValueError("edge_index must be a 2D tensor.")
|
raise ValueError("edge_index must be a 2D tensor.")
|
||||||
@@ -114,11 +131,13 @@ class Graph(Data):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _check_edge_attr_consistency(edge_attr, edge_index):
|
def _check_edge_attr_consistency(edge_attr, edge_index):
|
||||||
"""
|
"""
|
||||||
Check if the edge attr is consistent.
|
Check if the edge attribute tensor is consistent in type and shape
|
||||||
:param torch.Tensor edge_attr: The edge attribute tensor.
|
with the edge index.
|
||||||
|
|
||||||
|
:param torch.Tensor edge_attr: The edge attribute tensor.
|
||||||
:param torch.Tensor edge_index: The edge index tensor.
|
:param torch.Tensor edge_index: The edge index tensor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if edge_attr is not None:
|
if edge_attr is not None:
|
||||||
check_consistency(edge_attr, (torch.Tensor, LabelTensor))
|
check_consistency(edge_attr, (torch.Tensor, LabelTensor))
|
||||||
if edge_attr.ndim != 2:
|
if edge_attr.ndim != 2:
|
||||||
@@ -134,10 +153,13 @@ class Graph(Data):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _check_x_consistency(x, pos=None):
|
def _check_x_consistency(x, pos=None):
|
||||||
"""
|
"""
|
||||||
Check if the input tensor x is consistent with the position tensor pos.
|
Check if the input tensor x is consistent with the position tensor
|
||||||
|
`pos`.
|
||||||
|
|
||||||
:param torch.Tensor x: The input tensor.
|
:param torch.Tensor x: The input tensor.
|
||||||
:param torch.Tensor pos: The position tensor.
|
:param torch.Tensor pos: The position tensor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if x is not None:
|
if x is not None:
|
||||||
check_consistency(x, (torch.Tensor, LabelTensor))
|
check_consistency(x, (torch.Tensor, LabelTensor))
|
||||||
if x.ndim != 2:
|
if x.ndim != 2:
|
||||||
@@ -152,22 +174,24 @@ class Graph(Data):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _preprocess_edge_index(edge_index, undirected):
|
def _preprocess_edge_index(edge_index, undirected):
|
||||||
"""
|
"""
|
||||||
Preprocess the edge index.
|
Preprocess the edge index to make the graph undirected (if required).
|
||||||
|
|
||||||
:param torch.Tensor edge_index: The edge index.
|
:param torch.Tensor edge_index: The edge index.
|
||||||
:param bool undirected: Whether the graph is undirected.
|
:param bool undirected: Whether the graph is undirected.
|
||||||
:return: The preprocessed edge index.
|
:return: The preprocessed edge index.
|
||||||
:rtype: torch.Tensor
|
:rtype: torch.Tensor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if undirected:
|
if undirected:
|
||||||
edge_index = to_undirected(edge_index)
|
edge_index = to_undirected(edge_index)
|
||||||
return edge_index
|
return edge_index
|
||||||
|
|
||||||
def extract(self, labels, attr="x"):
|
def extract(self, labels, attr="x"):
|
||||||
"""
|
"""
|
||||||
Perform extraction of labels on node features (x)
|
Perform extraction of labels from the attribute specified by `attr`.
|
||||||
|
|
||||||
:param labels: Labels to extract
|
:param labels: Labels to extract
|
||||||
:type labels: list[str] | tuple[str] | str
|
:type labels: list[str] | tuple[str] | str | dict
|
||||||
:return: Batch object with extraction performed on x
|
:return: Batch object with extraction performed on x
|
||||||
:rtype: PinaBatch
|
:rtype: PinaBatch
|
||||||
"""
|
"""
|
||||||
@@ -193,21 +217,24 @@ class GraphBuilder:
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Creates a new instance of the Graph class.
|
Compute the edge attributes and create a new instance of the Graph
|
||||||
|
class.
|
||||||
|
|
||||||
:param pos: A tensor of shape (N, D) representing the positions of N
|
:param pos: A tensor of shape (N, D) representing the positions of N
|
||||||
points in D-dimensional space.
|
points in D-dimensional space.
|
||||||
:type pos: torch.Tensor | LabelTensor
|
:type pos: torch.Tensor or LabelTensor
|
||||||
:param edge_index: A tensor of shape (2, E) representing the indices of
|
:param edge_index: A tensor of shape (2, E) representing the indices of
|
||||||
the graph's edges.
|
the graph's edges.
|
||||||
:type edge_index: torch.Tensor
|
:type edge_index: torch.Tensor
|
||||||
:param x: Optional tensor of node features (N, F) where F is the number
|
:param x: Optional tensor of node features of shape (N, F), where F is
|
||||||
of features per node.
|
the number of features per node.
|
||||||
:type x: torch.Tensor, LabelTensor
|
:type x: torch.Tensor | LabelTensor, optional
|
||||||
:param bool edge_attr: Optional edge attributes (E, F) where F is the
|
:param edge_attr: Optional tensor of edge attributes of shape (E, F),
|
||||||
number of features per edge.
|
where F is the number of features per edge.
|
||||||
:param callable custom_edge_func: A custom function to compute edge
|
:type edge_attr: torch.Tensor, optional
|
||||||
attributes.
|
:param custom_edge_func: A custom function to compute edge attributes.
|
||||||
|
If provided, overrides `edge_attr`.
|
||||||
|
:type custom_edge_func: callable, optional
|
||||||
:param kwargs: Additional keyword arguments passed to the Graph class
|
:param kwargs: Additional keyword arguments passed to the Graph class
|
||||||
constructor.
|
constructor.
|
||||||
:return: A Graph instance constructed using the provided information.
|
:return: A Graph instance constructed using the provided information.
|
||||||
@@ -249,18 +276,18 @@ class RadiusGraph(GraphBuilder):
|
|||||||
|
|
||||||
def __new__(cls, pos, radius, **kwargs):
|
def __new__(cls, pos, radius, **kwargs):
|
||||||
"""
|
"""
|
||||||
Creates a new instance of the Graph class using a radius-based graph
|
Extends the `GraphBuilder` class to compute edge_index based on a
|
||||||
construction.
|
radius. Each point is connected to all the points within the radius.
|
||||||
|
|
||||||
:param pos: A tensor of shape (N, D) representing the positions of N
|
:param pos: A tensor of shape (N, D) representing the positions of N
|
||||||
points in D-dimensional space.
|
points in D-dimensional space.
|
||||||
:type pos: torch.Tensor | LabelTensor
|
:type pos: torch.Tensor or LabelTensor
|
||||||
:param float radius: The radius within which points are connected.
|
:param radius: The radius within which points are connected.
|
||||||
:Keyword Arguments:
|
:type radius: float
|
||||||
The additional keyword arguments to be passed to GraphBuilder
|
:param kwargs: Additional keyword arguments to be passed to the
|
||||||
and Graph classes
|
`GraphBuilder` and `Graph` constructors.
|
||||||
:return: Graph instance containg the information passed in input and
|
:return: A `Graph` instance containing the input information and the
|
||||||
the computed edge_index
|
computed edge_index.
|
||||||
:rtype: Graph
|
:rtype: Graph
|
||||||
"""
|
"""
|
||||||
edge_index = cls.compute_radius_graph(pos, radius)
|
edge_index = cls.compute_radius_graph(pos, radius)
|
||||||
@@ -269,7 +296,8 @@ class RadiusGraph(GraphBuilder):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def compute_radius_graph(points, radius):
|
def compute_radius_graph(points, radius):
|
||||||
"""
|
"""
|
||||||
Computes a radius-based graph for a given set of points.
|
Computes edge_index for a given set of points base on the radius.
|
||||||
|
Each point is connected to all the points within the radius.
|
||||||
|
|
||||||
:param points: A tensor of shape (N, D) representing the positions of
|
:param points: A tensor of shape (N, D) representing the positions of
|
||||||
N points in D-dimensional space.
|
N points in D-dimensional space.
|
||||||
@@ -295,7 +323,7 @@ class KNNGraph(GraphBuilder):
|
|||||||
def __new__(cls, pos, neighbours, **kwargs):
|
def __new__(cls, pos, neighbours, **kwargs):
|
||||||
"""
|
"""
|
||||||
Creates a new instance of the Graph class using k-nearest neighbors
|
Creates a new instance of the Graph class using k-nearest neighbors
|
||||||
to compute edge_index.
|
algorithm to define the edges.
|
||||||
|
|
||||||
:param pos: A tensor of shape (N, D) representing the positions of N
|
:param pos: A tensor of shape (N, D) representing the positions of N
|
||||||
points in D-dimensional space.
|
points in D-dimensional space.
|
||||||
@@ -323,8 +351,9 @@ class KNNGraph(GraphBuilder):
|
|||||||
N points in D-dimensional space.
|
N points in D-dimensional space.
|
||||||
:type points: torch.Tensor | LabelTensor
|
:type points: torch.Tensor | LabelTensor
|
||||||
:param int k: The number of nearest neighbors to find for each point.
|
:param int k: The number of nearest neighbors to find for each point.
|
||||||
:rtype torch.Tensor: A tensor of shape (2, E), where E is the number of
|
:return: A tensor of shape (2, E), where E is the number of
|
||||||
edges, representing the edge indices of the KNN graph.
|
edges, representing the edge indices of the KNN graph.
|
||||||
|
:rtype: torch.Tensor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
dist = torch.cdist(points, points, p=2)
|
dist = torch.cdist(points, points, p=2)
|
||||||
@@ -343,6 +372,11 @@ class LabelBatch(Batch):
|
|||||||
def from_data_list(cls, data_list):
|
def from_data_list(cls, data_list):
|
||||||
"""
|
"""
|
||||||
Create a Batch object from a list of Data objects.
|
Create a Batch object from a list of Data objects.
|
||||||
|
|
||||||
|
:param data_list: List of Data/Graph objects
|
||||||
|
:type data_list: list[Data] | list[Graph]
|
||||||
|
:return: A Batch object containing the data in the list
|
||||||
|
:rtype: Batch
|
||||||
"""
|
"""
|
||||||
# Store the labels of Data/Graph objects (all data have the same labels)
|
# Store the labels of Data/Graph objects (all data have the same labels)
|
||||||
# If the data do not contain labels, labels is an empty dictionary,
|
# If the data do not contain labels, labels is an empty dictionary,
|
||||||
|
|||||||
Reference in New Issue
Block a user