Update doc data

This commit is contained in:
FilippoOlivo
2025-03-12 09:25:31 +01:00
committed by Nicola Demo
parent 415dbcc72a
commit 1ba3c7a6e1
2 changed files with 17 additions and 16 deletions

View File

@@ -112,7 +112,7 @@ class Collator:
""" """
Function used to create a batch when automatic batching is disabled. Function used to create a batch when automatic batching is disabled.
:param list(int) batch: List of integers representing the indices of :param list[int] batch: List of integers representing the indices of
the data points to be fetched. the data points to be fetched.
:return: Dictionary containing the data points fetched from the dataset. :return: Dictionary containing the data points fetched from the dataset.
:rtype: dict :rtype: dict
@@ -124,7 +124,7 @@ class Collator:
""" """
Function used to collate the batch Function used to collate the batch
:param list(dict) batch: List of retrieved data. :param list[dict] batch: List of retrieved data.
:return: Dictionary containing the data points fetched from the dataset, :return: Dictionary containing the data points fetched from the dataset,
collated. collated.
:rtype: dict :rtype: dict
@@ -160,7 +160,7 @@ class Collator:
:class:`PinaTensorDataset`. :class:`PinaTensorDataset`.
:param data_list: Elements to be collated. :param data_list: Elements to be collated.
:type data_list: list(torch.Tensor) | list(LabelTensor) :type data_list: list[torch.Tensor] | list[LabelTensor]
:return: Batch of data. :return: Batch of data.
:rtype: dict :rtype: dict
@@ -180,7 +180,7 @@ class Collator:
:class:`PinaGraphDataset`. :class:`PinaGraphDataset`.
:param data_list: Elememts to be collated. :param data_list: Elememts to be collated.
:type data_list: list(torch_geometric.data.Data) | list(Graph) :type data_list: list[torch_geometric.data.Data] | list[Graph]
:return: Batch of data. :return: Batch of data.
:rtype: dict :rtype: dict
@@ -206,7 +206,7 @@ class Collator:
during class initialization. during class initialization.
:param batch: List of retrieved data or sampled indices. :param batch: List of retrieved data or sampled indices.
:type batch: list(int) | list(dict) :type batch: list[int] | list[dict]
:return: Dictionary containing the data points fetched from the dataset, :return: Dictionary containing the data points fetched from the dataset,
collated. collated.
:rtype: dict :rtype: dict
@@ -582,12 +582,12 @@ class PinaDataModule(LightningDataModule):
Transfer the batch to the device. This method is used when the batch Transfer the batch to the device. This method is used when the batch
size is None: batch has already been transferred to the device. size is None: batch has already been transferred to the device.
:param list(tuple) batch: list of tuple where the first element of the :param list[tuple] batch: List of tuple where the first element of the
tuple is the condition name and the second element is the data. tuple is the condition name and the second element is the data.
:param torch.device device: device to which the batch is transferred. :param torch.device device: Device to which the batch is transferred.
:param int dataloader_idx: index of the dataloader. :param int dataloader_idx: Index of the dataloader.
:return: The batch transferred to the device. :return: The batch transferred to the device.
:rtype: list(tuple) :rtype: list[tuple]
""" """
return batch return batch
@@ -602,7 +602,7 @@ class PinaDataModule(LightningDataModule):
transferred. transferred.
:param int dataloader_idx: The index of the dataloader. :param int dataloader_idx: The index of the dataloader.
:return: The batch transferred to the device. :return: The batch transferred to the device.
:rtype: list(tuple) :rtype: list[tuple]
""" """
batch = [ batch = [

View File

@@ -175,7 +175,7 @@ class PinaDataset(Dataset, ABC):
Return data from the dataset given a list of indices. Return data from the dataset given a list of indices.
:param idx: List of indices. :param idx: List of indices.
:type idx: list :type idx: list[int]
:return: A dictionary containing the data at the given indices. :return: A dictionary containing the data at the given indices.
:rtype: dict :rtype: dict
""" """
@@ -216,7 +216,7 @@ class PinaTensorDataset(PinaDataset):
:param data: Dictionary containing the data :param data: Dictionary containing the data
(only torch.Tensor/LableTensor). (only torch.Tensor/LableTensor).
:type data: dict :type data: dict
:param list(int) idx_list: indices to retrieve. :param list[int] idx_list: indices to retrieve.
:return: Dictionary containing the data at the given indices. :return: Dictionary containing the data at the given indices.
:rtype: dict :rtype: dict
""" """
@@ -246,7 +246,7 @@ class PinaGraphDataset(PinaDataset):
:class:`torch_geometric.data.Data` objects. :class:`torch_geometric.data.Data` objects.
:param data: List of items to collate in a single batch. :param data: List of items to collate in a single batch.
:type data: list(torch_geometric.data.Data) | list(Graph) :type data: list[torch_geometric.data.Data] | list[Graph]
:return: LabelBatch object all the graph collated in a single batch :return: LabelBatch object all the graph collated in a single batch
disconnected graphs. disconnected graphs.
:rtype: LabelBatch :rtype: LabelBatch
@@ -256,7 +256,8 @@ class PinaGraphDataset(PinaDataset):
def _create_tensor_batch(self, data): def _create_tensor_batch(self, data):
""" """
Create a torch.Tensor object from a list of torch.Tensor objects. Reshape properly ``data`` tensor to be processed handle by the graph
based models.
:param data: torch.Tensor object of shape (N, ...) where N is the :param data: torch.Tensor object of shape (N, ...) where N is the
number of data points. number of data points.
@@ -273,7 +274,7 @@ class PinaGraphDataset(PinaDataset):
objects. objects.
:param data: List of items to collate in a single batch. :param data: List of items to collate in a single batch.
:type data: list :type data: list[torch_geometric.data.Data] | list[Graph]
:return: Batch object. :return: Batch object.
:rtype: Batch | PinaBatch :rtype: Batch | PinaBatch
""" """
@@ -288,7 +289,7 @@ class PinaGraphDataset(PinaDataset):
Retrieve data from the dataset given a list of indices. Retrieve data from the dataset given a list of indices.
:param dict data: Dictionary containing the data. :param dict data: Dictionary containing the data.
:param list idx_list: List of indices to retrieve. :param list[int] idx_list: List of indices to retrieve.
:return: Dictionary containing the data at the given indices. :return: Dictionary containing the data at the given indices.
:rtype: dict :rtype: dict
""" """