Update doc data
This commit is contained in:
committed by
Nicola Demo
parent
415dbcc72a
commit
1ba3c7a6e1
@@ -112,7 +112,7 @@ class Collator:
|
||||
"""
|
||||
Function used to create a batch when automatic batching is disabled.
|
||||
|
||||
:param list(int) batch: List of integers representing the indices of
|
||||
:param list[int] batch: List of integers representing the indices of
|
||||
the data points to be fetched.
|
||||
:return: Dictionary containing the data points fetched from the dataset.
|
||||
:rtype: dict
|
||||
@@ -124,7 +124,7 @@ class Collator:
|
||||
"""
|
||||
Function used to collate the batch
|
||||
|
||||
:param list(dict) batch: List of retrieved data.
|
||||
:param list[dict] batch: List of retrieved data.
|
||||
:return: Dictionary containing the data points fetched from the dataset,
|
||||
collated.
|
||||
:rtype: dict
|
||||
@@ -160,7 +160,7 @@ class Collator:
|
||||
:class:`PinaTensorDataset`.
|
||||
|
||||
:param data_list: Elements to be collated.
|
||||
:type data_list: list(torch.Tensor) | list(LabelTensor)
|
||||
:type data_list: list[torch.Tensor] | list[LabelTensor]
|
||||
:return: Batch of data.
|
||||
:rtype: dict
|
||||
|
||||
@@ -180,7 +180,7 @@ class Collator:
|
||||
:class:`PinaGraphDataset`.
|
||||
|
||||
:param data_list: Elememts to be collated.
|
||||
:type data_list: list(torch_geometric.data.Data) | list(Graph)
|
||||
:type data_list: list[torch_geometric.data.Data] | list[Graph]
|
||||
:return: Batch of data.
|
||||
:rtype: dict
|
||||
|
||||
@@ -206,7 +206,7 @@ class Collator:
|
||||
during class initialization.
|
||||
|
||||
:param batch: List of retrieved data or sampled indices.
|
||||
:type batch: list(int) | list(dict)
|
||||
:type batch: list[int] | list[dict]
|
||||
:return: Dictionary containing the data points fetched from the dataset,
|
||||
collated.
|
||||
:rtype: dict
|
||||
@@ -582,12 +582,12 @@ class PinaDataModule(LightningDataModule):
|
||||
Transfer the batch to the device. This method is used when the batch
|
||||
size is None: batch has already been transferred to the device.
|
||||
|
||||
:param list(tuple) batch: list of tuple where the first element of the
|
||||
:param list[tuple] batch: List of tuple where the first element of the
|
||||
tuple is the condition name and the second element is the data.
|
||||
:param torch.device device: device to which the batch is transferred.
|
||||
:param int dataloader_idx: index of the dataloader.
|
||||
:param torch.device device: Device to which the batch is transferred.
|
||||
:param int dataloader_idx: Index of the dataloader.
|
||||
:return: The batch transferred to the device.
|
||||
:rtype: list(tuple)
|
||||
:rtype: list[tuple]
|
||||
"""
|
||||
|
||||
return batch
|
||||
@@ -602,7 +602,7 @@ class PinaDataModule(LightningDataModule):
|
||||
transferred.
|
||||
:param int dataloader_idx: The index of the dataloader.
|
||||
:return: The batch transferred to the device.
|
||||
:rtype: list(tuple)
|
||||
:rtype: list[tuple]
|
||||
"""
|
||||
|
||||
batch = [
|
||||
|
||||
@@ -175,7 +175,7 @@ class PinaDataset(Dataset, ABC):
|
||||
Return data from the dataset given a list of indices.
|
||||
|
||||
:param idx: List of indices.
|
||||
:type idx: list
|
||||
:type idx: list[int]
|
||||
:return: A dictionary containing the data at the given indices.
|
||||
:rtype: dict
|
||||
"""
|
||||
@@ -216,7 +216,7 @@ class PinaTensorDataset(PinaDataset):
|
||||
:param data: Dictionary containing the data
|
||||
(only torch.Tensor/LableTensor).
|
||||
:type data: dict
|
||||
:param list(int) idx_list: indices to retrieve.
|
||||
:param list[int] idx_list: indices to retrieve.
|
||||
:return: Dictionary containing the data at the given indices.
|
||||
:rtype: dict
|
||||
"""
|
||||
@@ -246,7 +246,7 @@ class PinaGraphDataset(PinaDataset):
|
||||
:class:`torch_geometric.data.Data` objects.
|
||||
|
||||
:param data: List of items to collate in a single batch.
|
||||
:type data: list(torch_geometric.data.Data) | list(Graph)
|
||||
:type data: list[torch_geometric.data.Data] | list[Graph]
|
||||
:return: LabelBatch object all the graph collated in a single batch
|
||||
disconnected graphs.
|
||||
:rtype: LabelBatch
|
||||
@@ -256,7 +256,8 @@ class PinaGraphDataset(PinaDataset):
|
||||
|
||||
def _create_tensor_batch(self, data):
|
||||
"""
|
||||
Create a torch.Tensor object from a list of torch.Tensor objects.
|
||||
Reshape properly ``data`` tensor to be processed handle by the graph
|
||||
based models.
|
||||
|
||||
:param data: torch.Tensor object of shape (N, ...) where N is the
|
||||
number of data points.
|
||||
@@ -273,7 +274,7 @@ class PinaGraphDataset(PinaDataset):
|
||||
objects.
|
||||
|
||||
:param data: List of items to collate in a single batch.
|
||||
:type data: list
|
||||
:type data: list[torch_geometric.data.Data] | list[Graph]
|
||||
:return: Batch object.
|
||||
:rtype: Batch | PinaBatch
|
||||
"""
|
||||
@@ -288,7 +289,7 @@ class PinaGraphDataset(PinaDataset):
|
||||
Retrieve data from the dataset given a list of indices.
|
||||
|
||||
:param dict data: Dictionary containing the data.
|
||||
:param list idx_list: List of indices to retrieve.
|
||||
:param list[int] idx_list: List of indices to retrieve.
|
||||
:return: Dictionary containing the data at the given indices.
|
||||
:rtype: dict
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user