diff --git a/pina/data/data_module.py b/pina/data/data_module.py index 8c9ea12..f1f405d 100644 --- a/pina/data/data_module.py +++ b/pina/data/data_module.py @@ -112,7 +112,7 @@ class Collator: """ Function used to create a batch when automatic batching is disabled. - :param list(int) batch: List of integers representing the indices of + :param list[int] batch: List of integers representing the indices of the data points to be fetched. :return: Dictionary containing the data points fetched from the dataset. :rtype: dict @@ -124,7 +124,7 @@ class Collator: """ Function used to collate the batch - :param list(dict) batch: List of retrieved data. + :param list[dict] batch: List of retrieved data. :return: Dictionary containing the data points fetched from the dataset, collated. :rtype: dict @@ -160,7 +160,7 @@ class Collator: :class:`PinaTensorDataset`. :param data_list: Elements to be collated. - :type data_list: list(torch.Tensor) | list(LabelTensor) + :type data_list: list[torch.Tensor] | list[LabelTensor] :return: Batch of data. :rtype: dict @@ -180,7 +180,7 @@ class Collator: :class:`PinaGraphDataset`. :param data_list: Elememts to be collated. - :type data_list: list(torch_geometric.data.Data) | list(Graph) + :type data_list: list[torch_geometric.data.Data] | list[Graph] :return: Batch of data. :rtype: dict @@ -206,7 +206,7 @@ class Collator: during class initialization. :param batch: List of retrieved data or sampled indices. - :type batch: list(int) | list(dict) + :type batch: list[int] | list[dict] :return: Dictionary containing the data points fetched from the dataset, collated. :rtype: dict @@ -582,12 +582,12 @@ class PinaDataModule(LightningDataModule): Transfer the batch to the device. This method is used when the batch size is None: batch has already been transferred to the device. - :param list(tuple) batch: list of tuple where the first element of the + :param list[tuple] batch: List of tuple where the first element of the tuple is the condition name and the second element is the data. - :param torch.device device: device to which the batch is transferred. - :param int dataloader_idx: index of the dataloader. + :param torch.device device: Device to which the batch is transferred. + :param int dataloader_idx: Index of the dataloader. :return: The batch transferred to the device. - :rtype: list(tuple) + :rtype: list[tuple] """ return batch @@ -602,7 +602,7 @@ class PinaDataModule(LightningDataModule): transferred. :param int dataloader_idx: The index of the dataloader. :return: The batch transferred to the device. - :rtype: list(tuple) + :rtype: list[tuple] """ batch = [ diff --git a/pina/data/dataset.py b/pina/data/dataset.py index ec405e1..b1d6c71 100644 --- a/pina/data/dataset.py +++ b/pina/data/dataset.py @@ -175,7 +175,7 @@ class PinaDataset(Dataset, ABC): Return data from the dataset given a list of indices. :param idx: List of indices. - :type idx: list + :type idx: list[int] :return: A dictionary containing the data at the given indices. :rtype: dict """ @@ -216,7 +216,7 @@ class PinaTensorDataset(PinaDataset): :param data: Dictionary containing the data (only torch.Tensor/LableTensor). :type data: dict - :param list(int) idx_list: indices to retrieve. + :param list[int] idx_list: indices to retrieve. :return: Dictionary containing the data at the given indices. :rtype: dict """ @@ -246,7 +246,7 @@ class PinaGraphDataset(PinaDataset): :class:`torch_geometric.data.Data` objects. :param data: List of items to collate in a single batch. - :type data: list(torch_geometric.data.Data) | list(Graph) + :type data: list[torch_geometric.data.Data] | list[Graph] :return: LabelBatch object all the graph collated in a single batch disconnected graphs. :rtype: LabelBatch @@ -256,7 +256,8 @@ class PinaGraphDataset(PinaDataset): def _create_tensor_batch(self, data): """ - Create a torch.Tensor object from a list of torch.Tensor objects. + Reshape properly ``data`` tensor to be processed handle by the graph + based models. :param data: torch.Tensor object of shape (N, ...) where N is the number of data points. @@ -273,7 +274,7 @@ class PinaGraphDataset(PinaDataset): objects. :param data: List of items to collate in a single batch. - :type data: list + :type data: list[torch_geometric.data.Data] | list[Graph] :return: Batch object. :rtype: Batch | PinaBatch """ @@ -288,7 +289,7 @@ class PinaGraphDataset(PinaDataset): Retrieve data from the dataset given a list of indices. :param dict data: Dictionary containing the data. - :param list idx_list: List of indices to retrieve. + :param list[int] idx_list: List of indices to retrieve. :return: Dictionary containing the data at the given indices. :rtype: dict """