Author: FilippoOlivo
Date:   2025-11-13 17:03:18 +01:00
Parent: 51a0399111
Commit: 0ee63686dd
3 changed files with 56 additions and 17 deletions


@@ -13,7 +13,7 @@ class DummyDataloader:
     DataLoader that returns the entire dataset in a single batch.
     """
 
-    def __init__(self, dataset):
+    def __init__(self, dataset, device=None):
         """
         Prepare a dataloader object that returns the entire dataset in a single
         batch. Depending on the number of GPUs, the dataset is managed
@@ -47,9 +47,14 @@ class DummyDataloader:
                 idx.append(i)
                 i += world_size
         else:
-            idx = list(range(len(dataset)))
+            idx = [i for i in range(len(dataset))]
         self.dataset = dataset.getitem_from_list(idx)
+        self.device = device
+        self.dataset = (
+            {k: v.to(self.device) for k, v in self.dataset.items()}
+            if self.device
+            else self.dataset
+        )
 
     def __iter__(self):
         """
@@ -155,12 +160,14 @@ class PinaDataLoader:
         shuffle=False,
         common_batch_size=True,
         separate_conditions=False,
+        device=None,
     ):
         self.dataset_dict = dataset_dict
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.shuffle = shuffle
         self.common_batch_size = common_batch_size
         self.separate_conditions = separate_conditions
+        self.device = device
         # Batch size None means we want to load the entire dataset in a single
         # batch
@@ -238,7 +245,7 @@ class PinaDataLoader:
"""
# If batch size is None, use DummyDataloader
if batch_size is None or batch_size >= len(dataset):
return DummyDataloader(dataset)
return DummyDataloader(dataset, device=self.device)
# Determine the appropriate collate function
if not dataset.automatic_batching:
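
Taken together, the hunks thread a single `device` argument from
PinaDataLoader.__init__ down to the DummyDataloader it builds for the
full-batch case (batch_size=None, or a batch size covering the whole
dataset). A hedged call-site sketch; only the keyword names come from the
diff, while `dataset_dict`, the import path, and the iteration idiom are
assumptions:

    import torch
    # Assumed import path; PinaDataLoader is defined in the file patched above.
    from pina.data import PinaDataLoader

    loader = PinaDataLoader(
        dataset_dict=dataset_dict,  # condition -> dataset mapping, built elsewhere
        batch_size=None,            # None routes to the DummyDataloader branch above
        num_workers=0,
        device="cuda" if torch.cuda.is_available() else None,
    )

    for batch in loader:
        ...  # tensors in each batch already live on the requested device

With device=None (the default) both classes behave exactly as before the
patch, so existing call sites are unaffected.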