Fix a bug and add additional tests for Dataset and DataModule (#517)

Author: Filippo Olivo
Date: 2025-03-25 12:18:27 +01:00 (committed by GitHub)
Parent: 03ef90c358
Commit: ef29f0a95d
4 changed files with 143 additions and 16 deletions


```diff
@@ -217,12 +217,11 @@ class PinaSampler:
     parameter and the environment in which the code is running.
     """
 
-    def __new__(cls, dataset, shuffle):
+    def __new__(cls, dataset):
         """
         Instantiate and initialize the sampler.
 
         :param PinaDataset dataset: The dataset from which to sample.
-        :param bool shuffle: Whether to shuffle the dataset.
         :return: The sampler instance.
         :rtype: :class:`torch.utils.data.Sampler`
         """
@@ -231,12 +230,9 @@
             torch.distributed.is_available()
             and torch.distributed.is_initialized()
         ):
-            sampler = DistributedSampler(dataset, shuffle=shuffle)
+            sampler = DistributedSampler(dataset)
         else:
-            if shuffle:
-                sampler = RandomSampler(dataset)
-            else:
-                sampler = SequentialSampler(dataset)
+            sampler = SequentialSampler(dataset)
         return sampler
```
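After this change, `PinaSampler` no longer takes a `shuffle` flag: it returns a `DistributedSampler` when a distributed environment is initialized and a plain `SequentialSampler` otherwise. Below is a minimal, self-contained sketch of the resulting selection logic; the method body mirrors the diff above, but the imports and docstrings are reconstructed for illustration, not copied verbatim from PINA's source.

```python
# Sketch of PinaSampler after this commit: the sampler choice depends only on
# whether torch.distributed is up, not on a shuffle flag.
import torch
from torch.utils.data import DistributedSampler, SequentialSampler


class PinaSampler:
    """Select a sampler based on the execution environment."""

    def __new__(cls, dataset):
        # Distributed run: shard the dataset across processes.
        if (
            torch.distributed.is_available()
            and torch.distributed.is_initialized()
        ):
            sampler = DistributedSampler(dataset)
        else:
            # Single-process run: iterate the dataset in order.
            sampler = SequentialSampler(dataset)
        return sampler
```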
```diff
@@ -496,8 +492,6 @@ class PinaDataModule(LightningDataModule):
         :return: The dataloader for the given split.
         :rtype: torch.utils.data.DataLoader
         """
-        shuffle = self.shuffle if split == "train" else False
-
         # Suppress the warning about num_workers.
         # In many cases, especially for PINNs,
         # serial data loading can outperform parallel data loading.
@@ -511,7 +505,7 @@
         )
         # Use custom batching (good if batch size is large)
         if self.batch_size is not None:
-            sampler = PinaSampler(dataset, shuffle)
+            sampler = PinaSampler(dataset)
             if self.automatic_batching:
                 collate = Collator(
                     self.find_max_conditions_lengths(split),
```
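For illustration, a hedged usage sketch of the simplified call site: `PinaSampler(dataset)` now needs only the dataset, and the resulting sampler can be handed to a standard `torch.utils.data.DataLoader`. The `TensorDataset` stand-in and the batch size below are placeholders, not PINA's actual `_create_dataloader` wiring.

```python
# Hypothetical usage: plug the sampler into a plain torch DataLoader.
# TensorDataset is a stand-in for a PinaDataset; batch_size=4 is arbitrary.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10, dtype=torch.float32))
sampler = PinaSampler(dataset)  # SequentialSampler outside distributed runs
loader = DataLoader(dataset, batch_size=4, sampler=sampler)
for (batch,) in loader:
    print(batch)  # batches arrive in dataset order
```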