diff --git a/pina/data/data_module.py b/pina/data/data_module.py
index ff4405b..20b3c1c 100644
--- a/pina/data/data_module.py
+++ b/pina/data/data_module.py
@@ -1,4 +1,5 @@
 import logging
+import warnings
 from lightning.pytorch import LightningDataModule
 import torch
 from ..label_tensor import LabelTensor
@@ -8,6 +9,7 @@ from torch.utils.data.distributed import DistributedSampler
 from .dataset import PinaDatasetFactory
 from ..collector import Collector
 
+
 class DummyDataloader:
     """
     Dummy dataloader used when batch size is None. It collects all the data
@@ -57,7 +59,7 @@ class Collator:
         self.max_conditions_lengths = max_conditions_lengths
         self.callable_function = self._collate_custom_dataloader if \
             max_conditions_lengths is None else (
-            self._collate_standard_dataloader)
+                self._collate_standard_dataloader)
         self.dataset = dataset
 
     def _collate_custom_dataloader(self, batch):
@@ -95,7 +97,7 @@ class Collator:
 
 class PinaSampler:
 
-    def __new__(self, dataset, batch_size, shuffle, automatic_batching):
+    def __new__(cls, dataset, shuffle):
 
         if (torch.distributed.is_available()
                 and torch.distributed.is_initialized()):
@@ -123,15 +125,35 @@ class PinaDataModule(LightningDataModule):
                  batch_size=None,
                  shuffle=True,
                  repeat=False,
-                 automatic_batching=False
+                 automatic_batching=False,
+                 num_workers=0,
+                 pin_memory=False,
                  ):
         """
-        Initialize the object, creating dataset based on input problem
-        :param problem: Problem where data are defined
-        :param train_size: number/percentage of elements in train split
-        :param test_size: number/percentage of elements in test split
-        :param val_size: number/percentage of elements in evaluation split
-        :param batch_size: batch size used for training
+        Initialize the object, creating datasets based on the input problem.
+
+        :param problem: The problem defining the dataset.
+        :type problem: AbstractProblem
+        :param train_size: Fraction or number of elements in the training split.
+        :type train_size: float
+        :param test_size: Fraction or number of elements in the test split.
+        :type test_size: float
+        :param val_size: Fraction or number of elements in the validation split.
+        :type val_size: float
+        :param predict_size: Fraction or number of elements in the prediction split.
+        :type predict_size: float
+        :param batch_size: Batch size used for training. If None, the entire dataset is used per batch.
+        :type batch_size: int or None
+        :param shuffle: Whether to shuffle the dataset before splitting.
+        :type shuffle: bool
+        :param repeat: Whether to repeat the dataset indefinitely.
+        :type repeat: bool
+        :param automatic_batching: Whether to enable automatic batching.
+        :type automatic_batching: bool
+        :param num_workers: Number of worker processes for data loading. Default 0 (data is loaded in the main process).
+        :type num_workers: int
+        :param pin_memory: Whether to use pinned memory for faster data transfer to the GPU. Default False.
+        :type pin_memory: bool
         """
         logging.debug('Start initialization of Pina DataModule')
         logging.info('Start initialization of Pina DataModule')
@@ -170,6 +192,15 @@ class PinaDataModule(LightningDataModule):
         collector = Collector(problem)
         collector.store_fixed_data()
         collector.store_sample_domains()
+        if batch_size is None and num_workers != 0:
+            warnings.warn(
+                "Setting num_workers when batch_size is None has no effect on "
+                "the data loading process.")
+        if batch_size is None and pin_memory:
+            warnings.warn("Setting pin_memory to True has no effect when "
+                          "batch_size is None.")
+        self.num_workers = num_workers
+        self.pin_memory = pin_memory
         self.collector_splits = self._create_splits(collector, splits_dict)
         self.transfer_batch_to_device = self._transfer_batch_to_device
@@ -271,20 +302,27 @@ class PinaDataModule(LightningDataModule):
                 dataset_dict[key].update({condition_name: data})
         return dataset_dict
 
-
     def _create_dataloader(self, split, dataset):
         shuffle = self.shuffle if split == 'train' else False
+        # Suppress the Lightning warning about too few workers: in many
+        # cases, especially for PINNs, serial loading outperforms parallel loading.
+        warnings.filterwarnings(
+            "ignore",
+            message=(
+                r"The '(train|val|test)_dataloader' does not have many workers which may be a bottleneck."),
+            module="lightning.pytorch.trainer.connectors.data_connector"
+        )
         # Use custom batching (good if batch size is large)
         if self.batch_size is not None:
-            sampler = PinaSampler(dataset, self.batch_size,
-                                  shuffle, self.automatic_batching)
+            sampler = PinaSampler(dataset, shuffle)
             if self.automatic_batching:
                 collate = Collator(self.find_max_conditions_lengths(split))
             else:
                 collate = Collator(None, dataset)
             return DataLoader(dataset, self.batch_size,
-                              collate_fn=collate, sampler=sampler)
+                              collate_fn=collate, sampler=sampler,
+                              num_workers=self.num_workers, pin_memory=self.pin_memory)
         dataloader = DummyDataloader(dataset)
         dataloader.dataset = self._transfer_batch_to_device(
             dataloader.dataset, self.trainer.strategy.root_device, 0)
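Note (not part of the diff): with `batch_size=None` the `DummyDataloader` above bypasses `torch.utils.data.DataLoader` entirely, which is why the two new guards warn that `num_workers` and `pin_memory` are inert in that case. A minimal sketch of the behaviour, assuming a `SupervisedProblem` as in the tests below:

```python
import warnings

import torch

from pina.data import PinaDataModule
from pina.problem.zoo import SupervisedProblem

# Illustrative problem: 100 samples, 10 input features, 2 outputs.
problem = SupervisedProblem(input_=torch.rand(100, 10),
                            output_=torch.rand(100, 2))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # batch_size=None -> DummyDataloader, so both arguments are ignored.
    PinaDataModule(problem, batch_size=None, num_workers=4, pin_memory=True)

# Expect one warning per ignored argument (other libraries may add more).
messages = [str(w.message) for w in caught]
assert any("num_workers" in m for m in messages)
assert any("pin_memory" in m for m in messages)
```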
diff --git a/pina/trainer.py b/pina/trainer.py
index bfd334d..0d15e76 100644
--- a/pina/trainer.py
+++ b/pina/trainer.py
@@ -18,6 +18,8 @@ class Trainer(lightning.pytorch.Trainer):
                  predict_size=0.,
                  compile=None,
                  automatic_batching=None,
+                 num_workers=None,
+                 pin_memory=None,
                  **kwargs):
         """
         PINA Trainer class for customizing every aspect of training via flags.
@@ -44,6 +46,10 @@ class Trainer(lightning.pytorch.Trainer):
             performed. Please avoid using automatic batching when batch_size
             is large, default False.
         :type automatic_batching: bool
+        :param num_workers: Number of worker processes for data loading. Default 0 (data is loaded in the main process).
+        :type num_workers: int
+        :param pin_memory: Whether to use pinned memory for faster data transfer to the GPU. Default False.
+        :type pin_memory: bool
 
         :Keyword Arguments:
             The additional keyword arguments specify the training setup
@@ -60,6 +66,14 @@ class Trainer(lightning.pytorch.Trainer):
             check_consistency(automatic_batching, bool)
         if compile is not None:
             check_consistency(compile, bool)
+        if pin_memory is not None:
+            check_consistency(pin_memory, bool)
+        else:
+            pin_memory = False
+        if num_workers is not None:
+            check_consistency(num_workers, int)
+        else:
+            num_workers = 0
         if train_size + test_size + val_size + predict_size > 1:
             raise ValueError('train_size, test_size, val_size and predict_size '
                              'must sum up to 1.')
@@ -93,19 +107,16 @@ class Trainer(lightning.pytorch.Trainer):
             compile = False
         if automatic_batching is None:
             automatic_batching = False
-
+        # set attributes
         self.compile = compile
-        self.automatic_batching = automatic_batching
-        self.train_size = train_size
-        self.test_size = test_size
-        self.val_size = val_size
-        self.predict_size = predict_size
         self.solver = solver
         self.batch_size = batch_size
         self._move_to_device()
         self.data_module = None
-        self._create_loader()
+        self._create_datamodule(train_size, test_size, val_size, predict_size,
+                                batch_size, automatic_batching, pin_memory,
+                                num_workers)
 
         # logging
         self.logging_kwargs = {
@@ -127,7 +138,15 @@ class Trainer(lightning.pytorch.Trainer):
                 pb.unknown_parameters[key] = torch.nn.Parameter(
                     pb.unknown_parameters[key].data.to(device))
 
-    def _create_loader(self):
+    def _create_datamodule(self,
+                           train_size,
+                           test_size,
+                           val_size,
+                           predict_size,
+                           batch_size,
+                           automatic_batching,
+                           pin_memory,
+                           num_workers):
         """
         This method is used here because, if resampling is needed during
         training, there is no need to redefine or touch the dataloader.
         """
         if not self.solver.problem.are_all_domains_discretised:
             error_message = '\n'.join([
                 f"""{" " * 13} ---> Domain {key} {
-                "sampled" if key in self.solver.problem.discretised_domains else
-                "not sampled"}""" for key in
+                    "sampled" if key in self.solver.problem.discretised_domains else
+                    "not sampled"}""" for key in
                 self.solver.problem.domains.keys()
             ])
             raise RuntimeError('Cannot create Trainer if not all conditions '
                                f'{error_message}')
         self.data_module = PinaDataModule(
             self.solver.problem,
-            train_size=self.train_size,
-            test_size=self.test_size,
-            val_size=self.val_size,
-            predict_size=self.predict_size,
-            batch_size=self.batch_size,
-            automatic_batching=self.automatic_batching)
+            train_size=train_size,
+            test_size=test_size,
+            val_size=val_size,
+            predict_size=predict_size,
+            batch_size=batch_size,
+            automatic_batching=automatic_batching,
+            num_workers=num_workers,
+            pin_memory=pin_memory)
 
     def train(self, **kwargs):
         """
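A hedged usage sketch of the new `Trainer` flags (the model and split sizes are illustrative, mirroring the tests below): both arguments are validated with `check_consistency`, defaulted when omitted, and forwarded to `PinaDataModule`:

```python
import torch

from pina import Trainer
from pina.problem.zoo import SupervisedProblem
from pina.solvers import SupervisedSolver

problem = SupervisedProblem(input_=torch.rand(100, 10),
                            output_=torch.rand(100, 2))
solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 2))

# num_workers=2 spawns two loader worker processes; pin_memory=True
# requests page-locked host buffers for faster host-to-GPU copies.
trainer = Trainer(solver, batch_size=10, train_size=.7, val_size=.3,
                  test_size=0., num_workers=2, pin_memory=True)
assert trainer.data_module.num_workers == 2
assert trainer.data_module.pin_memory is True
```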
diff --git a/tests/test_data/test_datamodule.py b/tests/test_data/test_datamodule.py
new file mode 100644
index 0000000..866eebc
--- /dev/null
+++ b/tests/test_data/test_datamodule.py
@@ -0,0 +1,178 @@
+import torch
+import pytest
+from pina.data import PinaDataModule
+from pina.data.dataset import PinaTensorDataset, PinaGraphDataset
+from pina.problem.zoo import SupervisedProblem
+from pina.graph import RadiusGraph
+from pina.data.data_module import DummyDataloader
+from pina import Trainer
+from pina.solvers import SupervisedSolver
+from torch_geometric.data import Batch
+from torch.utils.data import DataLoader
+
+input_tensor = torch.rand((100, 10))
+output_tensor = torch.rand((100, 2))
+
+x = torch.rand((100, 50, 10))
+pos = torch.rand((100, 50, 2))
+input_graph = RadiusGraph(x, pos, r=.1, build_edge_attr=True)
+output_graph = torch.rand((100, 50, 10))
+
+
+@pytest.mark.parametrize(
+    "input_, output_",
+    [
+        (input_tensor, output_tensor),
+        (input_graph, output_graph)
+    ]
+)
+def test_constructor(input_, output_):
+    problem = SupervisedProblem(input_=input_, output_=output_)
+    PinaDataModule(problem)
+
+
+@pytest.mark.parametrize(
+    "input_, output_",
+    [
+        (input_tensor, output_tensor),
+        (input_graph, output_graph)
+    ]
+)
+@pytest.mark.parametrize(
+    "train_size, val_size, test_size",
+    [
+        (.7, .2, .1),
+        (.7, .3, 0)
+    ]
+)
+def test_setup_train(input_, output_, train_size, val_size, test_size):
+    problem = SupervisedProblem(input_=input_, output_=output_)
+    dm = PinaDataModule(problem, train_size=train_size, val_size=val_size,
+                        test_size=test_size)
+    dm.setup()
+    assert hasattr(dm, "train_dataset")
+    if isinstance(input_, torch.Tensor):
+        assert isinstance(dm.train_dataset, PinaTensorDataset)
+    else:
+        assert isinstance(dm.train_dataset, PinaGraphDataset)
+    # assert len(dm.train_dataset) == int(len(input_) * train_size)
+    if test_size > 0:
+        assert hasattr(dm, "test_dataset")
+        assert dm.test_dataset is None
+    else:
+        assert not hasattr(dm, "test_dataset")
+    assert hasattr(dm, "val_dataset")
+    if isinstance(input_, torch.Tensor):
+        assert isinstance(dm.val_dataset, PinaTensorDataset)
+    else:
+        assert isinstance(dm.val_dataset, PinaGraphDataset)
+    # assert len(dm.val_dataset) == int(len(input_) * val_size)
+
+
+@pytest.mark.parametrize(
+    "input_, output_",
+    [
+        (input_tensor, output_tensor),
+        (input_graph, output_graph)
+    ]
+)
+@pytest.mark.parametrize(
+    "train_size, val_size, test_size",
+    [
+        (.7, .2, .1),
+        (0., 0., 1.)
+    ]
+)
+def test_setup_test(input_, output_, train_size, val_size, test_size):
+    problem = SupervisedProblem(input_=input_, output_=output_)
+    dm = PinaDataModule(problem, train_size=train_size, val_size=val_size,
+                        test_size=test_size)
+    dm.setup(stage='test')
+    if train_size > 0:
+        assert hasattr(dm, "train_dataset")
+        assert dm.train_dataset is None
+    else:
+        assert not hasattr(dm, "train_dataset")
+    if val_size > 0:
+        assert hasattr(dm, "val_dataset")
+        assert dm.val_dataset is None
+    else:
+        assert not hasattr(dm, "val_dataset")
+
+    assert hasattr(dm, "test_dataset")
+    if isinstance(input_, torch.Tensor):
+        assert isinstance(dm.test_dataset, PinaTensorDataset)
+    else:
+        assert isinstance(dm.test_dataset, PinaGraphDataset)
+    # assert len(dm.test_dataset) == int(len(input_) * test_size)
+
+
+@pytest.mark.parametrize(
+    "input_, output_",
+    [
+        (input_tensor, output_tensor),
+        (input_graph, output_graph)
+    ]
+)
+def test_dummy_dataloader(input_, output_):
+    problem = SupervisedProblem(input_=input_, output_=output_)
+    solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
+    trainer = Trainer(solver, batch_size=None, train_size=.7, val_size=.3,
+                      test_size=0.)
+    dm = trainer.data_module
+    dm.setup()
+    dm.trainer = trainer
+    dataloader = dm.train_dataloader()
+    assert isinstance(dataloader, DummyDataloader)
+    assert len(dataloader) == 1
+    data = next(dataloader)
+    assert isinstance(data, list)
+    assert isinstance(data[0], tuple)
+    if isinstance(input_, RadiusGraph):
+        assert isinstance(data[0][1]['input_points'], Batch)
+    else:
+        assert isinstance(data[0][1]['input_points'], torch.Tensor)
+    assert isinstance(data[0][1]['output_points'], torch.Tensor)
+
+    dataloader = dm.val_dataloader()
+    assert isinstance(dataloader, DummyDataloader)
+    assert len(dataloader) == 1
+    data = next(dataloader)
+    assert isinstance(data, list)
+    assert isinstance(data[0], tuple)
+    if isinstance(input_, RadiusGraph):
+        assert isinstance(data[0][1]['input_points'], Batch)
+    else:
+        assert isinstance(data[0][1]['input_points'], torch.Tensor)
+    assert isinstance(data[0][1]['output_points'], torch.Tensor)
+
+
+@pytest.mark.parametrize(
+    "input_, output_",
+    [
+        (input_tensor, output_tensor),
+        (input_graph, output_graph)
+    ]
+)
+def test_dataloader(input_, output_):
+    problem = SupervisedProblem(input_=input_, output_=output_)
+    solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
+    trainer = Trainer(solver, batch_size=10, train_size=.7, val_size=.3,
+                      test_size=0.)
+    dm = trainer.data_module
+    dm.setup()
+    dm.trainer = trainer
+    dataloader = dm.train_dataloader()
+    assert isinstance(dataloader, DataLoader)
+    assert len(dataloader) == 7
+    data = next(iter(dataloader))
+    assert isinstance(data, dict)
+    if isinstance(input_, RadiusGraph):
+        assert isinstance(data['data']['input_points'], Batch)
+    else:
+        assert isinstance(data['data']['input_points'], torch.Tensor)
+    assert isinstance(data['data']['output_points'], torch.Tensor)
+
+    dataloader = dm.val_dataloader()
+    assert isinstance(dataloader, DataLoader)
+    assert len(dataloader) == 3
+    data = next(iter(dataloader))
+    assert isinstance(data, dict)
+    if isinstance(input_, RadiusGraph):
+        assert isinstance(data['data']['input_points'], Batch)
+    else:
+        assert isinstance(data['data']['input_points'], torch.Tensor)
+    assert isinstance(data['data']['output_points'], torch.Tensor)
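A possible companion test for this file (a sketch, assuming the `DataLoader` wiring in the data_module.py hunk above, where `num_workers` and `pin_memory` are passed through) checking that the new flags actually reach the returned loader:

```python
def test_dataloader_num_workers_and_pin_memory():
    # Sketch only: relies on torch's DataLoader exposing num_workers and
    # pin_memory as attributes, and on PinaDataModule forwarding them.
    problem = SupervisedProblem(input_=input_tensor, output_=output_tensor)
    solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
    trainer = Trainer(solver, batch_size=10, train_size=.7, val_size=.3,
                      test_size=0., num_workers=2, pin_memory=True)
    dm = trainer.data_module
    dm.setup()
    dm.trainer = trainer
    dataloader = dm.train_dataloader()
    assert dataloader.num_workers == 2
    assert dataloader.pin_memory is True
```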
+        (conditions_dict_single_multi, max_conditions_lengths_multi)
+    ]
+)
+def test_constructor(conditions_dict, max_conditions_lengths):
+    dataset = PinaDatasetFactory(conditions_dict,
+                                 max_conditions_lengths=max_conditions_lengths,
+                                 automatic_batching=True)
+    assert isinstance(dataset, PinaGraphDataset)
+    assert len(dataset) == 100
+
+
+@pytest.mark.parametrize(
+    "conditions_dict, max_conditions_lengths",
+    [
+        (conditions_dict_single, max_conditions_lengths_single),
+        (conditions_dict_single_multi, max_conditions_lengths_multi)
+    ]
+)
+def test_getitem(conditions_dict, max_conditions_lengths):
+    dataset = PinaDatasetFactory(conditions_dict,
+                                 max_conditions_lengths=max_conditions_lengths,
+                                 automatic_batching=True)
+    data = dataset[50]
+    assert isinstance(data, dict)
+    assert all([isinstance(d['input_points'], Data)
+                for d in data.values()])
+    assert all([isinstance(d['output_points'], torch.Tensor)
+                for d in data.values()])
+    assert all([d['input_points'].x.shape == torch.Size((20, 10))
+                for d in data.values()])
+    assert all([d['output_points'].shape == torch.Size((20, 10))
+                for d in data.values()])
+    assert all([d['input_points'].edge_index.shape ==
+                torch.Size((2, 60)) for d in data.values()])
+    assert all([d['input_points'].edge_attr.shape[0]
+                == 60 for d in data.values()])
+
+    data = dataset.fetch_from_idx_list([i for i in range(20)])
+    assert isinstance(data, dict)
+    assert all([isinstance(d['input_points'], Data)
+                for d in data.values()])
+    assert all([isinstance(d['output_points'], torch.Tensor)
+                for d in data.values()])
+    assert all([d['input_points'].x.shape == torch.Size((400, 10))
+                for d in data.values()])
+    assert all([d['output_points'].shape == torch.Size((400, 10))
+                for d in data.values()])
+    assert all([d['input_points'].edge_index.shape ==
+                torch.Size((2, 1200)) for d in data.values()])
+    assert all([d['input_points'].edge_attr.shape[0]
+                == 1200 for d in data.values()])
diff --git a/tests/test_data/test_tensor_dataset.py b/tests/test_data/test_tensor_dataset.py
new file mode 100644
index 0000000..230cae4
--- /dev/null
+++ b/tests/test_data/test_tensor_dataset.py
@@ -0,0 +1,88 @@
+import torch
+import pytest
+from pina.data.dataset import PinaDatasetFactory, PinaTensorDataset
+
+input_tensor = torch.rand((100, 10))
+output_tensor = torch.rand((100, 2))
+
+input_tensor_2 = torch.rand((50, 10))
+output_tensor_2 = torch.rand((50, 2))
+
+conditions_dict_single = {
+    'data': {
+        'input_points': input_tensor,
+        'output_points': output_tensor,
+    }
+}
+
+conditions_dict_single_multi = {
+    'data_1': {
+        'input_points': input_tensor,
+        'output_points': output_tensor,
+    },
+    'data_2': {
+        'input_points': input_tensor_2,
+        'output_points': output_tensor_2,
+    }
+}
+
+max_conditions_lengths_single = {
+    'data': 100
+}
+
+max_conditions_lengths_multi = {
+    'data_1': 100,
+    'data_2': 50
+}
+
+
+@pytest.mark.parametrize(
+    "conditions_dict, max_conditions_lengths",
+    [
+        (conditions_dict_single, max_conditions_lengths_single),
+        (conditions_dict_single_multi, max_conditions_lengths_multi)
+    ]
+)
+def test_constructor_tensor(conditions_dict, max_conditions_lengths):
+    dataset = PinaDatasetFactory(conditions_dict,
+                                 max_conditions_lengths=max_conditions_lengths,
+                                 automatic_batching=True)
+    assert isinstance(dataset, PinaTensorDataset)
+
+
+def test_getitem_single():
+    dataset = PinaDatasetFactory(conditions_dict_single,
+                                 max_conditions_lengths=max_conditions_lengths_single,
+                                 automatic_batching=False)
+
+    tensors = dataset.fetch_from_idx_list([i for i in range(70)])
+    assert isinstance(tensors, dict)
+    assert list(tensors.keys()) == ['data']
+    assert sorted(list(tensors['data'].keys())) == [
+        'input_points', 'output_points']
+    assert isinstance(tensors['data']['input_points'], torch.Tensor)
+    assert tensors['data']['input_points'].shape == torch.Size((70, 10))
+    assert isinstance(tensors['data']['output_points'], torch.Tensor)
+    assert tensors['data']['output_points'].shape == torch.Size((70, 2))
+
+
+def test_getitem_multi():
+    dataset = PinaDatasetFactory(conditions_dict_single_multi,
+                                 max_conditions_lengths=max_conditions_lengths_multi,
+                                 automatic_batching=False)
+    tensors = dataset.fetch_from_idx_list([i for i in range(70)])
+    assert isinstance(tensors, dict)
+    assert list(tensors.keys()) == ['data_1', 'data_2']
+    assert sorted(list(tensors['data_1'].keys())) == [
+        'input_points', 'output_points']
+    assert isinstance(tensors['data_1']['input_points'], torch.Tensor)
+    assert tensors['data_1']['input_points'].shape == torch.Size((70, 10))
+    assert isinstance(tensors['data_1']['output_points'], torch.Tensor)
+    assert tensors['data_1']['output_points'].shape == torch.Size((70, 2))
+
+    assert sorted(list(tensors['data_2'].keys())) == [
+        'input_points', 'output_points']
+    assert isinstance(tensors['data_2']['input_points'], torch.Tensor)
+    assert tensors['data_2']['input_points'].shape == torch.Size((50, 10))
+    assert isinstance(tensors['data_2']['output_points'], torch.Tensor)
+    assert tensors['data_2']['output_points'].shape == torch.Size((50, 2))
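Finally, a sketch of a possible indexing test for the tensor dataset, parallel to `test_getitem` in the graph tests above (the per-sample return shape is an assumption, so only types and keys are asserted):

```python
def test_getitem_single_automatic():
    # Sketch: assumes dataset[idx] returns one sample per condition when
    # automatic batching is enabled, as in the graph dataset tests above.
    dataset = PinaDatasetFactory(conditions_dict_single,
                                 max_conditions_lengths=max_conditions_lengths_single,
                                 automatic_batching=True)
    data = dataset[50]
    assert isinstance(data, dict)
    assert isinstance(data['data']['input_points'], torch.Tensor)
    assert isinstance(data['data']['output_points'], torch.Tensor)
```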