fix doc
This commit is contained in:
@@ -163,6 +163,7 @@ class PinaDataLoader:
|
||||
):
|
||||
"""
|
||||
Initialize the PinaDataLoader.
|
||||
|
||||
:param dict dataset_dict: A dictionary mapping dataset names to their
|
||||
respective PinaDataset instances.
|
||||
:param int batch_size: The batch size for the dataloader.
|
||||
@@ -172,6 +173,7 @@ class PinaDataLoader:
|
||||
"common_batch_size", "separate_conditions", and "proportional".
|
||||
:param device: The device to which the data should be moved.
|
||||
"""
|
||||
|
||||
self.dataset_dict = dataset_dict
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
@@ -209,6 +211,7 @@ class PinaDataLoader:
|
||||
"""
|
||||
Compute an appropriate batch size for the given dataset.
|
||||
"""
|
||||
|
||||
# Compute number of elements per dataset
|
||||
elements_per_dataset = {
|
||||
dataset_name: len(dataset)
|
||||
@@ -281,6 +284,7 @@ class PinaDataLoader:
|
||||
def __len__(self):
|
||||
"""
|
||||
Return the length of the dataloader.
|
||||
|
||||
:return: The length of the dataloader.
|
||||
:rtype: int
|
||||
"""
|
||||
@@ -293,6 +297,7 @@ class PinaDataLoader:
|
||||
def __iter__(self):
|
||||
"""
|
||||
Iterate over the dataloader.
|
||||
|
||||
:return: Yields batches from the dataloader.
|
||||
:rtype: dict
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user