fix
This commit is contained in:
@@ -13,7 +13,7 @@ class DummyDataloader:
|
||||
DataLoader that returns the entire dataset in a single batch.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset):
|
||||
def __init__(self, dataset, device=None):
|
||||
"""
|
||||
Prepare a dataloader object that returns the entire dataset in a single
|
||||
batch. Depending on the number of GPUs, the dataset is managed
|
||||
@@ -47,9 +47,14 @@ class DummyDataloader:
|
||||
idx.append(i)
|
||||
i += world_size
|
||||
else:
|
||||
idx = list(range(len(dataset)))
|
||||
|
||||
idx = [i for i in range(len(dataset))]
|
||||
self.dataset = dataset.getitem_from_list(idx)
|
||||
self.device = device
|
||||
self.dataset = (
|
||||
{k: v.to(self.device) for k, v in self.dataset.items()}
|
||||
if self.device
|
||||
else self.dataset
|
||||
)
|
||||
|
||||
def __iter__(self):
|
||||
"""
|
||||
@@ -155,12 +160,14 @@ class PinaDataLoader:
|
||||
shuffle=False,
|
||||
common_batch_size=True,
|
||||
separate_conditions=False,
|
||||
device=None,
|
||||
):
|
||||
self.dataset_dict = dataset_dict
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
self.shuffle = shuffle
|
||||
self.separate_conditions = separate_conditions
|
||||
self.device = device
|
||||
|
||||
# Batch size None means we want to load the entire dataset in a single
|
||||
# batch
|
||||
@@ -238,7 +245,7 @@ class PinaDataLoader:
|
||||
"""
|
||||
# If batch size is None, use DummyDataloader
|
||||
if batch_size is None or batch_size >= len(dataset):
|
||||
return DummyDataloader(dataset)
|
||||
return DummyDataloader(dataset, device=self.device)
|
||||
|
||||
# Determine the appropriate collate function
|
||||
if not dataset.automatic_batching:
|
||||
|
||||
Reference in New Issue
Block a user