fix tests and modules
This commit is contained in:
@@ -47,7 +47,7 @@ class DummyDataloader:
|
||||
idx.append(i)
|
||||
i += world_size
|
||||
else:
|
||||
idx = [i for i in range(len(dataset))]
|
||||
idx = list(range(len(dataset)))
|
||||
self.dataset = dataset.getitem_from_list(idx)
|
||||
self.device = device
|
||||
self.dataset = (
|
||||
@@ -158,15 +158,25 @@ class PinaDataLoader:
|
||||
batch_size,
|
||||
num_workers=0,
|
||||
shuffle=False,
|
||||
common_batch_size=True,
|
||||
separate_conditions=False,
|
||||
batching_mode="common_batch_size",
|
||||
device=None,
|
||||
):
|
||||
"""
|
||||
Initialize the PinaDataLoader.
|
||||
:param dict dataset_dict: A dictionary mapping dataset names to their
|
||||
respective PinaDataset instances.
|
||||
:param int batch_size: The batch size for the dataloader.
|
||||
:param int num_workers: Number of worker processes for data loading.
|
||||
:param bool shuffle: Whether to shuffle the data at every epoch.
|
||||
:param str batching_mode: The batching mode to use. Options are
|
||||
"common_batch_size", "separate_conditions", and "proportional".
|
||||
:param device: The device to which the data should be moved.
|
||||
"""
|
||||
self.dataset_dict = dataset_dict
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
self.shuffle = shuffle
|
||||
self.separate_conditions = separate_conditions
|
||||
self.batching_mode = batching_mode.lower()
|
||||
self.device = device
|
||||
|
||||
# Batch size None means we want to load the entire dataset in a single
|
||||
@@ -177,13 +187,13 @@ class PinaDataLoader:
|
||||
}
|
||||
else:
|
||||
# Compute batch size per dataset
|
||||
if common_batch_size: # all datasets have the same batch size
|
||||
if batching_mode in ["common_batch_size", "separate_conditions"]:
|
||||
# (the sum of the batch sizes is equal to
|
||||
# n_conditions * batch_size)
|
||||
batch_size_per_dataset = {
|
||||
split: batch_size for split in dataset_dict.keys()
|
||||
}
|
||||
else: # batch size proportional to dataset size (the sum of the
|
||||
elif batching_mode == "propotional":
|
||||
# batch sizes is equal to the specified batch size)
|
||||
batch_size_per_dataset = self._compute_batch_size()
|
||||
|
||||
@@ -242,6 +252,12 @@ class PinaDataLoader:
|
||||
def _create_dataloader(self, dataset, batch_size):
|
||||
"""
|
||||
Create the dataloader for the given dataset.
|
||||
|
||||
:param PinaDataset dataset: The dataset for which to create the
|
||||
dataloader.
|
||||
:param int batch_size: The batch size for the dataloader.
|
||||
:return: The created dataloader.
|
||||
:rtype: :class:`torch.utils.data.DataLoader`
|
||||
"""
|
||||
# If batch size is None, use DummyDataloader
|
||||
if batch_size is None or batch_size >= len(dataset):
|
||||
@@ -270,7 +286,7 @@ class PinaDataLoader:
|
||||
"""
|
||||
# If separate conditions, return sum of lengths of all dataloaders
|
||||
# else, return max length among dataloaders
|
||||
if self.separate_conditions:
|
||||
if self.batching_mode == "separate_conditions":
|
||||
return sum(len(dl) for dl in self.dataloaders.values())
|
||||
return max(len(dl) for dl in self.dataloaders.values())
|
||||
|
||||
@@ -280,7 +296,7 @@ class PinaDataLoader:
|
||||
:return: Yields batches from the dataloader.
|
||||
:rtype: dict
|
||||
"""
|
||||
if self.separate_conditions:
|
||||
if self.batching_mode == "separate_conditions":
|
||||
for split, dl in self.dataloaders.items():
|
||||
for batch in dl:
|
||||
yield {split: batch}
|
||||
|
||||
Reference in New Issue
Block a user