diff --git a/pina/data/data_module.py b/pina/data/data_module.py
index 9a0cf0a..603e810 100644
--- a/pina/data/data_module.py
+++ b/pina/data/data_module.py
@@ -27,8 +27,7 @@ class PinaDataModule(LightningDataModule):
         val_size=0.1,
         batch_size=None,
         shuffle=True,
-        common_batch_size=True,
-        separate_conditions=False,
+        batching_mode="common_batch_size",
         automatic_batching=None,
         num_workers=0,
         pin_memory=False,
@@ -84,8 +83,7 @@ class PinaDataModule(LightningDataModule):
         # Store fixed attributes
         self.batch_size = batch_size
         self.shuffle = shuffle
-        self.common_batch_size = common_batch_size
-        self.separate_conditions = separate_conditions
+        self.batching_mode = batching_mode
         self.automatic_batching = automatic_batching
 
         # If batch size is None, num_workers has no effect
@@ -280,8 +278,7 @@ class PinaDataModule(LightningDataModule):
                 batch_size=self.batch_size,
                 shuffle=self.shuffle,
                 num_workers=self.num_workers,
-                common_batch_size=self.common_batch_size,
-                separate_conditions=self.separate_conditions,
+                batching_mode=self.batching_mode,
                 device=self.trainer.strategy.root_device,
             )
             if self.batch_size is None:
@@ -330,7 +327,7 @@ class PinaDataModule(LightningDataModule):
         :rtype: list[tuple]
         """
 
-        return [(k, v) for k, v in batch.items()]
+        return list(batch.items())
 
     def _transfer_batch_to_device(self, batch, device, dataloader_idx):
         """
diff --git a/pina/data/dataloader.py b/pina/data/dataloader.py
index 6267868..9bfec54 100644
--- a/pina/data/dataloader.py
+++ b/pina/data/dataloader.py
@@ -47,7 +47,7 @@ class DummyDataloader:
                 idx.append(i)
                 i += world_size
         else:
-            idx = [i for i in range(len(dataset))]
+            idx = list(range(len(dataset)))
         self.dataset = dataset.getitem_from_list(idx)
         self.device = device
         self.dataset = (
@@ -158,15 +158,25 @@
         batch_size,
         num_workers=0,
         shuffle=False,
-        common_batch_size=True,
-        separate_conditions=False,
+        batching_mode="common_batch_size",
         device=None,
     ):
+        """
+        Initialize the PinaDataLoader.
+        :param dict dataset_dict: A dictionary mapping dataset names to their
+            respective PinaDataset instances.
+        :param int batch_size: The batch size for the dataloader.
+        :param int num_workers: Number of worker processes for data loading.
+        :param bool shuffle: Whether to shuffle the data at every epoch.
+        :param str batching_mode: The batching mode to use. Options are
+            "common_batch_size", "separate_conditions", and "proportional".
+        :param device: The device to which the data should be moved.
+        """
         self.dataset_dict = dataset_dict
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.shuffle = shuffle
-        self.separate_conditions = separate_conditions
+        self.batching_mode = batching_mode.lower()
         self.device = device
 
         # Batch size None means we want to load the entire dataset in a single
@@ -177,13 +187,13 @@
             }
         else:
             # Compute batch size per dataset
-            if common_batch_size:  # all datasets have the same batch size
+            if batching_mode in ["common_batch_size", "separate_conditions"]:
                 # (the sum of the batch sizes is equal to
                 # n_conditions * batch_size)
                 batch_size_per_dataset = {
                     split: batch_size for split in dataset_dict.keys()
                 }
-            else:  # batch size proportional to dataset size (the sum of the
+            elif batching_mode == "proportional":
                 # batch sizes is equal to the specified batch size)
                 batch_size_per_dataset = self._compute_batch_size()
 
@@ -242,6 +252,12 @@
     def _create_dataloader(self, dataset, batch_size):
         """
         Create the dataloader for the given dataset.
+
+        :param PinaDataset dataset: The dataset for which to create the
+            dataloader.
+        :param int batch_size: The batch size for the dataloader.
+        :return: The created dataloader.
+        :rtype: :class:`torch.utils.data.DataLoader`
         """
         # If batch size is None, use DummyDataloader
         if batch_size is None or batch_size >= len(dataset):
@@ -270,7 +286,7 @@
         """
         # If separate conditions, return sum of lengths of all dataloaders
         # else, return max length among dataloaders
-        if self.separate_conditions:
+        if self.batching_mode == "separate_conditions":
             return sum(len(dl) for dl in self.dataloaders.values())
         return max(len(dl) for dl in self.dataloaders.values())
 
@@ -280,7 +296,7 @@
         :return: Yields batches from the dataloader.
         :rtype: dict
         """
-        if self.separate_conditions:
+        if self.batching_mode == "separate_conditions":
             for split, dl in self.dataloaders.items():
                 for batch in dl:
                     yield {split: batch}
diff --git a/pina/trainer.py b/pina/trainer.py
index 840f2a6..a7e9654 100644
--- a/pina/trainer.py
+++ b/pina/trainer.py
@@ -31,8 +31,7 @@ class Trainer(lightning.pytorch.Trainer):
         test_size=0.0,
         val_size=0.0,
         compile=None,
-        common_batch_size=True,
-        separate_conditions=False,
+        batching_mode="common_batch_size",
         automatic_batching=None,
         num_workers=None,
         pin_memory=None,
@@ -85,8 +84,7 @@ class Trainer(lightning.pytorch.Trainer):
             train_size=train_size,
             test_size=test_size,
             val_size=val_size,
-            common_batch_size=common_batch_size,
-            seperate_conditions=separate_conditions,
+            batching_mode=batching_mode,
             automatic_batching=automatic_batching,
             compile=compile,
         )
@@ -141,8 +139,7 @@ class Trainer(lightning.pytorch.Trainer):
             test_size=test_size,
             val_size=val_size,
             batch_size=batch_size,
-            common_batch_size=common_batch_size,
-            seperate_conditions=separate_conditions,
+            batching_mode=batching_mode,
             automatic_batching=automatic_batching,
             pin_memory=pin_memory,
             num_workers=num_workers,
@@ -180,8 +177,7 @@ class Trainer(lightning.pytorch.Trainer):
         test_size,
         val_size,
         batch_size,
-        common_batch_size,
-        seperate_conditions,
+        batching_mode,
         automatic_batching,
         pin_memory,
         num_workers,
@@ -233,8 +229,7 @@ class Trainer(lightning.pytorch.Trainer):
             test_size=test_size,
             val_size=val_size,
             batch_size=batch_size,
-            common_batch_size=common_batch_size,
-            separate_conditions=seperate_conditions,
+            batching_mode=batching_mode,
             automatic_batching=automatic_batching,
             num_workers=num_workers,
             pin_memory=pin_memory,
@@ -286,8 +281,7 @@ class Trainer(lightning.pytorch.Trainer):
         train_size,
         test_size,
         val_size,
-        common_batch_size,
-        seperate_conditions,
+        batching_mode,
         automatic_batching,
         compile,
     ):
@@ -314,8 +308,7 @@ class Trainer(lightning.pytorch.Trainer):
         check_consistency(train_size, float)
         check_consistency(test_size, float)
         check_consistency(val_size, float)
-        check_consistency(common_batch_size, bool)
-        check_consistency(seperate_conditions, bool)
+        check_consistency(batching_mode, str)
         if automatic_batching is not None:
             check_consistency(automatic_batching, bool)
         if compile is not None:
diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py
index c17a4ed..ed87e62 100644
--- a/tests/test_data/test_data_module.py
+++ b/tests/test_data/test_data_module.py
@@ -159,7 +159,11 @@ def test_setup_test(input_, output_, train_size, val_size, test_size):
     [(input_tensor, output_tensor), (input_graph, output_graph)],
 )
 @pytest.mark.parametrize("automatic_batching", [True, False])
-def test_dataloader(input_, output_, automatic_batching):
+@pytest.mark.parametrize("batch_size", [None, 10]) +@pytest.mark.parametrize("batching_mode", ["common_batch_size", "propotional"]) +def test_dataloader( + input_, output_, automatic_batching, batch_size, batching_mode +): problem = SupervisedProblem(input_=input_, output_=output_) solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10)) trainer = Trainer( @@ -169,7 +173,7 @@ def test_dataloader(input_, output_, automatic_batching): val_size=0.3, test_size=0.0, automatic_batching=automatic_batching, - common_batch_size=True, + batching_mode=batching_mode, ) dm = trainer.data_module dm.setup() @@ -187,7 +191,7 @@ def test_dataloader(input_, output_, automatic_batching): dataloader = dm.val_dataloader() assert isinstance(dataloader, PinaDataLoader) - assert len(dataloader) == 3 + assert len(dataloader) == 3 if batch_size is not None else 1 data = next(iter(dataloader)) assert isinstance(data, dict) if isinstance(input_, list): @@ -225,7 +229,7 @@ def test_dataloader_labels(input_, output_, automatic_batching): val_size=0.3, test_size=0.0, automatic_batching=automatic_batching, - common_batch_size=True, + # common_batch_size=True, ) dm = trainer.data_module dm.setup() diff --git a/tests/test_solver/test_supervised_solver.py b/tests/test_solver/test_supervised_solver.py index 6f7d1ab..5d9709b 100644 --- a/tests/test_solver/test_supervised_solver.py +++ b/tests/test_solver/test_supervised_solver.py @@ -117,6 +117,10 @@ def test_solver_train(use_lt, batch_size, compile): assert isinstance(solver.model, OptimizedModule) +if __name__ == "__main__": + test_solver_train(use_lt=True, batch_size=20, compile=True) + + @pytest.mark.parametrize("batch_size", [None, 1, 5, 20]) @pytest.mark.parametrize("use_lt", [True, False]) def test_solver_train_graph(batch_size, use_lt):