From ef29f0a95da1807e5a4f38f42d61d046ef651405 Mon Sep 17 00:00:00 2001 From: Filippo Olivo Date: Tue, 25 Mar 2025 12:18:27 +0100 Subject: [PATCH] Bug fix and add additional tests for Dataset and DataModule (#517) --- pina/data/data_module.py | 14 ++--- pina/data/dataset.py | 22 ++++++- tests/test_data/test_data_module.py | 91 +++++++++++++++++++++++++++ tests/test_data/test_graph_dataset.py | 32 +++++++++- 4 files changed, 143 insertions(+), 16 deletions(-) diff --git a/pina/data/data_module.py b/pina/data/data_module.py index 349d74d..566f646 100644 --- a/pina/data/data_module.py +++ b/pina/data/data_module.py @@ -217,12 +217,11 @@ class PinaSampler: parameter and the environment in which the code is running. """ - def __new__(cls, dataset, shuffle): + def __new__(cls, dataset): """ Instantiate and initialize the sampler. :param PinaDataset dataset: The dataset from which to sample. - :param bool shuffle: Whether to shuffle the dataset. :return: The sampler instance. :rtype: :class:`torch.utils.data.Sampler` """ @@ -231,12 +230,9 @@ class PinaSampler: torch.distributed.is_available() and torch.distributed.is_initialized() ): - sampler = DistributedSampler(dataset, shuffle=shuffle) + sampler = DistributedSampler(dataset) else: - if shuffle: - sampler = RandomSampler(dataset) - else: - sampler = SequentialSampler(dataset) + sampler = SequentialSampler(dataset) return sampler @@ -496,8 +492,6 @@ class PinaDataModule(LightningDataModule): :return: The dataloader for the given split. :rtype: torch.utils.data.DataLoader """ - - shuffle = self.shuffle if split == "train" else False # Suppress the warning about num_workers. # In many cases, especially for PINNs, # serial data loading can outperform parallel data loading. @@ -511,7 +505,7 @@ class PinaDataModule(LightningDataModule): ) # Use custom batching (good if batch size is large) if self.batch_size is not None: - sampler = PinaSampler(dataset, shuffle) + sampler = PinaSampler(dataset) if self.automatic_batching: collate = Collator( self.find_max_conditions_lengths(split), diff --git a/pina/data/dataset.py b/pina/data/dataset.py index 54c1556..8d58be4 100644 --- a/pina/data/dataset.py +++ b/pina/data/dataset.py @@ -167,9 +167,15 @@ class PinaDataset(Dataset, ABC): :return: A dictionary containing all the data in the dataset. :rtype: dict """ - - index = list(range(len(self))) - return self.fetch_from_idx_list(index) + to_return_dict = {} + for condition, data in self.conditions_dict.items(): + len_condition = len( + data["input"] + ) # Length of the current condition + to_return_dict[condition] = self._retrive_data( + data, list(range(len_condition)) + ) # Retrieve the data from the current condition + return to_return_dict def fetch_from_idx_list(self, idx): """ @@ -306,3 +312,13 @@ class PinaGraphDataset(PinaDataset): ) for k, v in data.items() } + + @property + def input(self): + """ + Return the input data for the dataset. + + :return: Dictionary containing the input points. + :rtype: dict + """ + return {k: v["input"] for k, v in self.conditions_dict.items()} diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py index fe7b3eb..53e7334 100644 --- a/tests/test_data/test_data_module.py +++ b/tests/test_data/test_data_module.py @@ -238,3 +238,94 @@ def test_dataloader_labels(input_, output_, automatic_batching): assert data["data"]["input"].labels == ["u", "v", "w"] assert isinstance(data["data"]["target"], torch.Tensor) assert data["data"]["target"].labels == ["u", "v", "w"] + + +def test_get_all_data(): + input = torch.stack([torch.zeros((1,)) + i for i in range(1000)]) + target = input + + problem = SupervisedProblem(input, target) + datamodule = PinaDataModule( + problem, + train_size=0.7, + test_size=0.2, + val_size=0.1, + batch_size=64, + shuffle=False, + repeat=False, + automatic_batching=None, + num_workers=0, + pin_memory=False, + ) + datamodule.setup("fit") + datamodule.setup("test") + assert len(datamodule.train_dataset.get_all_data()["data"]["input"]) == 700 + assert torch.isclose( + datamodule.train_dataset.get_all_data()["data"]["input"], input[:700] + ).all() + assert len(datamodule.val_dataset.get_all_data()["data"]["input"]) == 100 + assert torch.isclose( + datamodule.val_dataset.get_all_data()["data"]["input"], input[900:] + ).all() + assert len(datamodule.test_dataset.get_all_data()["data"]["input"]) == 200 + assert torch.isclose( + datamodule.test_dataset.get_all_data()["data"]["input"], input[700:900] + ).all() + + +def test_input_propery_tensor(): + input = torch.stack([torch.zeros((1,)) + i for i in range(1000)]) + target = input + + problem = SupervisedProblem(input, target) + datamodule = PinaDataModule( + problem, + train_size=0.7, + test_size=0.2, + val_size=0.1, + batch_size=64, + shuffle=False, + repeat=False, + automatic_batching=None, + num_workers=0, + pin_memory=False, + ) + datamodule.setup("fit") + datamodule.setup("test") + input_ = datamodule.input + assert isinstance(input_, dict) + assert isinstance(input_["train"], dict) + assert isinstance(input_["val"], dict) + assert isinstance(input_["test"], dict) + assert torch.isclose(input_["train"]["data"], input[:700]).all() + assert torch.isclose(input_["val"]["data"], input[900:]).all() + assert torch.isclose(input_["test"]["data"], input[700:900]).all() + + +def test_input_propery_graph(): + problem = SupervisedProblem(input_graph, output_graph) + datamodule = PinaDataModule( + problem, + train_size=0.7, + test_size=0.2, + val_size=0.1, + batch_size=64, + shuffle=False, + repeat=False, + automatic_batching=None, + num_workers=0, + pin_memory=False, + ) + datamodule.setup("fit") + datamodule.setup("test") + input_ = datamodule.input + assert isinstance(input_, dict) + assert isinstance(input_["train"], dict) + assert isinstance(input_["val"], dict) + assert isinstance(input_["test"], dict) + assert isinstance(input_["train"]["data"], list) + assert isinstance(input_["val"]["data"], list) + assert isinstance(input_["test"]["data"], list) + assert len(input_["train"]["data"]) == 70 + assert len(input_["val"]["data"]) == 10 + assert len(input_["test"]["data"]) == 20 diff --git a/tests/test_data/test_graph_dataset.py b/tests/test_data/test_graph_dataset.py index 1fe0c89..a49b0ad 100644 --- a/tests/test_data/test_graph_dataset.py +++ b/tests/test_data/test_graph_dataset.py @@ -31,7 +31,7 @@ conditions_dict_single = { max_conditions_lengths_single = {"data": 100} # Problem with multiple conditions -conditions_dict_single_multi = { +conditions_dict_multi = { "data_1": { "input": input_, "target": output_, @@ -49,7 +49,7 @@ max_conditions_lengths_multi = {"data_1": 100, "data_2": 50} "conditions_dict, max_conditions_lengths", [ (conditions_dict_single, max_conditions_lengths_single), - (conditions_dict_single_multi, max_conditions_lengths_multi), + (conditions_dict_multi, max_conditions_lengths_multi), ], ) def test_constructor(conditions_dict, max_conditions_lengths): @@ -66,7 +66,7 @@ def test_constructor(conditions_dict, max_conditions_lengths): "conditions_dict, max_conditions_lengths", [ (conditions_dict_single, max_conditions_lengths_single), - (conditions_dict_single_multi, max_conditions_lengths_multi), + (conditions_dict_multi, max_conditions_lengths_multi), ], ) def test_getitem(conditions_dict, max_conditions_lengths): @@ -110,3 +110,29 @@ def test_getitem(conditions_dict, max_conditions_lengths): ] ) assert all([d["input"].edge_attr.shape[0] == 1200 for d in data.values()]) + + +def test_input_single_condition(): + dataset = PinaDatasetFactory( + conditions_dict_single, + max_conditions_lengths=max_conditions_lengths_single, + automatic_batching=True, + ) + input_ = dataset.input + assert isinstance(input_, dict) + assert isinstance(input_["data"], list) + assert all([isinstance(d, Data) for d in input_["data"]]) + + +def test_input_multi_condition(): + dataset = PinaDatasetFactory( + conditions_dict_multi, + max_conditions_lengths=max_conditions_lengths_multi, + automatic_batching=True, + ) + input_ = dataset.input + assert isinstance(input_, dict) + assert isinstance(input_["data_1"], list) + assert all([isinstance(d, Data) for d in input_["data_1"]]) + assert isinstance(input_["data_2"], list) + assert all([isinstance(d, Data) for d in input_["data_2"]])