Bug fix and add additional tests for Dataset and DataModule (#517)
This commit is contained in:
committed by
FilippoOlivo
parent
79a7199985
commit
80c257da4d
@@ -217,12 +217,11 @@ class PinaSampler:
|
||||
parameter and the environment in which the code is running.
|
||||
"""
|
||||
|
||||
def __new__(cls, dataset, shuffle):
|
||||
def __new__(cls, dataset):
|
||||
"""
|
||||
Instantiate and initialize the sampler.
|
||||
|
||||
:param PinaDataset dataset: The dataset from which to sample.
|
||||
:param bool shuffle: Whether to shuffle the dataset.
|
||||
:return: The sampler instance.
|
||||
:rtype: :class:`torch.utils.data.Sampler`
|
||||
"""
|
||||
@@ -231,12 +230,9 @@ class PinaSampler:
|
||||
torch.distributed.is_available()
|
||||
and torch.distributed.is_initialized()
|
||||
):
|
||||
sampler = DistributedSampler(dataset, shuffle=shuffle)
|
||||
sampler = DistributedSampler(dataset)
|
||||
else:
|
||||
if shuffle:
|
||||
sampler = RandomSampler(dataset)
|
||||
else:
|
||||
sampler = SequentialSampler(dataset)
|
||||
sampler = SequentialSampler(dataset)
|
||||
return sampler
|
||||
|
||||
|
||||
@@ -496,8 +492,6 @@ class PinaDataModule(LightningDataModule):
|
||||
:return: The dataloader for the given split.
|
||||
:rtype: torch.utils.data.DataLoader
|
||||
"""
|
||||
|
||||
shuffle = self.shuffle if split == "train" else False
|
||||
# Suppress the warning about num_workers.
|
||||
# In many cases, especially for PINNs,
|
||||
# serial data loading can outperform parallel data loading.
|
||||
@@ -511,7 +505,7 @@ class PinaDataModule(LightningDataModule):
|
||||
)
|
||||
# Use custom batching (good if batch size is large)
|
||||
if self.batch_size is not None:
|
||||
sampler = PinaSampler(dataset, shuffle)
|
||||
sampler = PinaSampler(dataset)
|
||||
if self.automatic_batching:
|
||||
collate = Collator(
|
||||
self.find_max_conditions_lengths(split),
|
||||
|
||||
@@ -167,9 +167,15 @@ class PinaDataset(Dataset, ABC):
|
||||
:return: A dictionary containing all the data in the dataset.
|
||||
:rtype: dict
|
||||
"""
|
||||
|
||||
index = list(range(len(self)))
|
||||
return self.fetch_from_idx_list(index)
|
||||
to_return_dict = {}
|
||||
for condition, data in self.conditions_dict.items():
|
||||
len_condition = len(
|
||||
data["input"]
|
||||
) # Length of the current condition
|
||||
to_return_dict[condition] = self._retrive_data(
|
||||
data, list(range(len_condition))
|
||||
) # Retrieve the data from the current condition
|
||||
return to_return_dict
|
||||
|
||||
def fetch_from_idx_list(self, idx):
|
||||
"""
|
||||
@@ -306,3 +312,13 @@ class PinaGraphDataset(PinaDataset):
|
||||
)
|
||||
for k, v in data.items()
|
||||
}
|
||||
|
||||
@property
|
||||
def input(self):
|
||||
"""
|
||||
Return the input data for the dataset.
|
||||
|
||||
:return: Dictionary containing the input points.
|
||||
:rtype: dict
|
||||
"""
|
||||
return {k: v["input"] for k, v in self.conditions_dict.items()}
|
||||
|
||||
Reference in New Issue
Block a user