fix tests and modules

FilippoOlivo
2025-11-14 16:52:10 +01:00
parent 8440a672a7
commit 43163fdf74
5 changed files with 47 additions and 33 deletions

View File

@@ -27,8 +27,7 @@ class PinaDataModule(LightningDataModule):
         val_size=0.1,
         batch_size=None,
         shuffle=True,
-        common_batch_size=True,
-        separate_conditions=False,
+        batching_mode="common_batch_size",
         automatic_batching=None,
         num_workers=0,
         pin_memory=False,
@@ -84,8 +83,7 @@ class PinaDataModule(LightningDataModule):
         # Store fixed attributes
         self.batch_size = batch_size
         self.shuffle = shuffle
-        self.common_batch_size = common_batch_size
-        self.separate_conditions = separate_conditions
+        self.batching_mode = batching_mode
         self.automatic_batching = automatic_batching
         # If batch size is None, num_workers has no effect
@@ -280,8 +278,7 @@ class PinaDataModule(LightningDataModule):
             batch_size=self.batch_size,
             shuffle=self.shuffle,
             num_workers=self.num_workers,
-            common_batch_size=self.common_batch_size,
-            separate_conditions=self.separate_conditions,
+            batching_mode=self.batching_mode,
             device=self.trainer.strategy.root_device,
         )
         if self.batch_size is None:
@@ -330,7 +327,7 @@ class PinaDataModule(LightningDataModule):
         :rtype: list[tuple]
         """
-        return [(k, v) for k, v in batch.items()]
+        return list(batch.items())

     def _transfer_batch_to_device(self, batch, device, dataloader_idx):
         """

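The data module collapses the former common_batch_size/separate_conditions boolean pair into a single batching_mode string. A minimal usage sketch under the new signature (assuming PinaDataModule is importable from pina.data and that problem is any PINA problem instance; all values are illustrative):

    from pina.data import PinaDataModule

    # Hypothetical usage of the renamed argument; "separate_conditions"
    # here replaces the old separate_conditions=True flag.
    dm = PinaDataModule(
        problem,
        train_size=0.7,
        val_size=0.1,
        batch_size=32,
        batching_mode="separate_conditions",
    )
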
View File

@@ -47,7 +47,7 @@ class DummyDataloader:
                 idx.append(i)
                 i += world_size
         else:
-            idx = [i for i in range(len(dataset))]
+            idx = list(range(len(dataset)))
         self.dataset = dataset.getitem_from_list(idx)
         self.device = device
         self.dataset = (
@@ -158,15 +158,25 @@ class PinaDataLoader:
         batch_size,
         num_workers=0,
         shuffle=False,
-        common_batch_size=True,
-        separate_conditions=False,
+        batching_mode="common_batch_size",
         device=None,
     ):
+        """
+        Initialize the PinaDataLoader.
+
+        :param dict dataset_dict: A dictionary mapping dataset names to their
+            respective PinaDataset instances.
+        :param int batch_size: The batch size for the dataloader.
+        :param int num_workers: Number of worker processes for data loading.
+        :param bool shuffle: Whether to shuffle the data at every epoch.
+        :param str batching_mode: The batching mode to use. Options are
+            "common_batch_size", "separate_conditions", and "proportional".
+        :param device: The device to which the data should be moved.
+        """
         self.dataset_dict = dataset_dict
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.shuffle = shuffle
-        self.separate_conditions = separate_conditions
+        self.batching_mode = batching_mode.lower()
         self.device = device
         # Batch size None means we want to load the entire dataset in a single
@@ -177,13 +187,13 @@ class PinaDataLoader:
             }
         else:
             # Compute batch size per dataset
-            if common_batch_size:  # all datasets have the same batch size
+            if batching_mode in ["common_batch_size", "separate_conditions"]:
                 # (the sum of the batch sizes is equal to
                 # n_conditions * batch_size)
                 batch_size_per_dataset = {
                     split: batch_size for split in dataset_dict.keys()
                 }
-            else:  # batch size proportional to dataset size (the sum of the
+            elif batching_mode == "proportional":
                 # batch sizes is equal to the specified batch size)
                 batch_size_per_dataset = self._compute_batch_size()
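
The "proportional" branch delegates to _compute_batch_size, whose body is outside this diff. A sketch of the arithmetic the comments describe, splitting the batch size across conditions in proportion to dataset length (an illustration under those assumptions, not the committed implementation):

    def _compute_batch_size(self):
        # Split self.batch_size across conditions in proportion to each
        # dataset's length, guaranteeing at least one sample per condition.
        total = sum(len(ds) for ds in self.dataset_dict.values())
        return {
            split: max(1, round(self.batch_size * len(ds) / total))
            for split, ds in self.dataset_dict.items()
        }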
@@ -242,6 +252,12 @@ class PinaDataLoader:
     def _create_dataloader(self, dataset, batch_size):
         """
         Create the dataloader for the given dataset.
+
+        :param PinaDataset dataset: The dataset for which to create the
+            dataloader.
+        :param int batch_size: The batch size for the dataloader.
+        :return: The created dataloader.
+        :rtype: :class:`torch.utils.data.DataLoader`
         """
         # If batch size is None, use DummyDataloader
         if batch_size is None or batch_size >= len(dataset):
@@ -270,7 +286,7 @@ class PinaDataLoader:
""" """
# If separate conditions, return sum of lengths of all dataloaders # If separate conditions, return sum of lengths of all dataloaders
# else, return max length among dataloaders # else, return max length among dataloaders
if self.separate_conditions: if self.batching_mode == "separate_conditions":
return sum(len(dl) for dl in self.dataloaders.values()) return sum(len(dl) for dl in self.dataloaders.values())
return max(len(dl) for dl in self.dataloaders.values()) return max(len(dl) for dl in self.dataloaders.values())
@@ -280,7 +296,7 @@ class PinaDataLoader:
         :return: Yields batches from the dataloader.
         :rtype: dict
         """
-        if self.separate_conditions:
+        if self.batching_mode == "separate_conditions":
             for split, dl in self.dataloaders.items():
                 for batch in dl:
                     yield {split: batch}

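Under "separate_conditions" the loader yields one condition per step, which is why its length is the sum over dataloaders; the other modes merge all conditions into each batch. A hedged iteration sketch (loader built as above; names are illustrative):

    # Each yielded batch is a dict: under "separate_conditions" it has
    # exactly one key, otherwise one key per condition.
    for batch in loader:
        for condition_name, data in batch.items():
            ...  # per-condition forward/loss computation
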
View File

@@ -31,8 +31,7 @@ class Trainer(lightning.pytorch.Trainer):
         test_size=0.0,
         val_size=0.0,
         compile=None,
-        common_batch_size=True,
-        separate_conditions=False,
+        batching_mode="common_batch_size",
         automatic_batching=None,
         num_workers=None,
         pin_memory=None,
@@ -85,8 +84,7 @@ class Trainer(lightning.pytorch.Trainer):
             train_size=train_size,
             test_size=test_size,
             val_size=val_size,
-            common_batch_size=common_batch_size,
-            seperate_conditions=separate_conditions,
+            batching_mode=batching_mode,
             automatic_batching=automatic_batching,
             compile=compile,
         )
@@ -141,8 +139,7 @@ class Trainer(lightning.pytorch.Trainer):
             test_size=test_size,
             val_size=val_size,
             batch_size=batch_size,
-            common_batch_size=common_batch_size,
-            seperate_conditions=separate_conditions,
+            batching_mode=batching_mode,
             automatic_batching=automatic_batching,
             pin_memory=pin_memory,
             num_workers=num_workers,
@@ -180,8 +177,7 @@ class Trainer(lightning.pytorch.Trainer):
         test_size,
         val_size,
         batch_size,
-        common_batch_size,
-        seperate_conditions,
+        batching_mode,
         automatic_batching,
         pin_memory,
         num_workers,
@@ -233,8 +229,7 @@ class Trainer(lightning.pytorch.Trainer):
             test_size=test_size,
             val_size=val_size,
             batch_size=batch_size,
-            common_batch_size=common_batch_size,
-            separate_conditions=seperate_conditions,
+            batching_mode=batching_mode,
             automatic_batching=automatic_batching,
             num_workers=num_workers,
             pin_memory=pin_memory,
@@ -286,8 +281,7 @@ class Trainer(lightning.pytorch.Trainer):
         train_size,
         test_size,
         val_size,
-        common_batch_size,
-        seperate_conditions,
+        batching_mode,
         automatic_batching,
         compile,
     ):
@@ -314,8 +308,7 @@ class Trainer(lightning.pytorch.Trainer):
         check_consistency(train_size, float)
         check_consistency(test_size, float)
         check_consistency(val_size, float)
-        check_consistency(common_batch_size, bool)
-        check_consistency(seperate_conditions, bool)
+        check_consistency(batching_mode, str)
         if automatic_batching is not None:
             check_consistency(automatic_batching, bool)
         if compile is not None:

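The Trainer forwards the renamed argument straight through to PinaDataModule. A minimal call-site sketch after the change (assuming from pina import Trainer and a solver built as in the tests; values are illustrative):

    from pina import Trainer

    trainer = Trainer(
        solver,
        max_epochs=5,
        train_size=0.7,
        val_size=0.3,
        batch_size=10,
        batching_mode="proportional",  # was: common_batch_size=False
    )
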
View File

@@ -159,7 +159,11 @@ def test_setup_test(input_, output_, train_size, val_size, test_size):
     [(input_tensor, output_tensor), (input_graph, output_graph)],
 )
 @pytest.mark.parametrize("automatic_batching", [True, False])
-def test_dataloader(input_, output_, automatic_batching):
+@pytest.mark.parametrize("batch_size", [None, 10])
+@pytest.mark.parametrize("batching_mode", ["common_batch_size", "proportional"])
+def test_dataloader(
+    input_, output_, automatic_batching, batch_size, batching_mode
+):
     problem = SupervisedProblem(input_=input_, output_=output_)
     solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
     trainer = Trainer(
@@ -169,7 +173,7 @@ def test_dataloader(input_, output_, automatic_batching):
         val_size=0.3,
         test_size=0.0,
         automatic_batching=automatic_batching,
-        common_batch_size=True,
+        batching_mode=batching_mode,
     )
     dm = trainer.data_module
     dm.setup()
@@ -187,7 +191,7 @@ def test_dataloader(input_, output_, automatic_batching):
     dataloader = dm.val_dataloader()
     assert isinstance(dataloader, PinaDataLoader)
-    assert len(dataloader) == 3
+    assert len(dataloader) == (3 if batch_size is not None else 1)
     data = next(iter(dataloader))
     assert isinstance(data, dict)
     if isinstance(input_, list):
@@ -225,7 +229,7 @@ def test_dataloader_labels(input_, output_, automatic_batching):
         val_size=0.3,
         test_size=0.0,
         automatic_batching=automatic_batching,
-        common_batch_size=True,
+        # common_batch_size=True,
     )
     dm = trainer.data_module
     dm.setup()

View File

@@ -117,6 +117,10 @@ def test_solver_train(use_lt, batch_size, compile):
         assert isinstance(solver.model, OptimizedModule)
 
 
+if __name__ == "__main__":
+    test_solver_train(use_lt=True, batch_size=20, compile=True)
+
+
 @pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
 @pytest.mark.parametrize("use_lt", [True, False])
 def test_solver_train_graph(batch_size, use_lt):