fix tests and modules

This commit is contained in:
FilippoOlivo
2025-11-14 16:52:10 +01:00
parent 8440a672a7
commit 43163fdf74
5 changed files with 47 additions and 33 deletions

View File

@@ -159,7 +159,11 @@ def test_setup_test(input_, output_, train_size, val_size, test_size):
[(input_tensor, output_tensor), (input_graph, output_graph)],
)
@pytest.mark.parametrize("automatic_batching", [True, False])
def test_dataloader(input_, output_, automatic_batching):
@pytest.mark.parametrize("batch_size", [None, 10])
@pytest.mark.parametrize("batching_mode", ["common_batch_size", "propotional"])
def test_dataloader(
input_, output_, automatic_batching, batch_size, batching_mode
):
problem = SupervisedProblem(input_=input_, output_=output_)
solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
trainer = Trainer(
@@ -169,7 +173,7 @@ def test_dataloader(input_, output_, automatic_batching):
val_size=0.3,
test_size=0.0,
automatic_batching=automatic_batching,
common_batch_size=True,
batching_mode=batching_mode,
)
dm = trainer.data_module
dm.setup()
@@ -187,7 +191,7 @@ def test_dataloader(input_, output_, automatic_batching):
dataloader = dm.val_dataloader()
assert isinstance(dataloader, PinaDataLoader)
assert len(dataloader) == 3
assert len(dataloader) == 3 if batch_size is not None else 1
data = next(iter(dataloader))
assert isinstance(data, dict)
if isinstance(input_, list):
@@ -225,7 +229,7 @@ def test_dataloader_labels(input_, output_, automatic_batching):
val_size=0.3,
test_size=0.0,
automatic_batching=automatic_batching,
common_batch_size=True,
# common_batch_size=True,
)
dm = trainer.data_module
dm.setup()

View File

@@ -117,6 +117,10 @@ def test_solver_train(use_lt, batch_size, compile):
assert isinstance(solver.model, OptimizedModule)
if __name__ == "__main__":
test_solver_train(use_lt=True, batch_size=20, compile=True)
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("use_lt", [True, False])
def test_solver_train_graph(batch_size, use_lt):