fix tests and modules
This commit is contained in:
@@ -159,7 +159,11 @@ def test_setup_test(input_, output_, train_size, val_size, test_size):
|
||||
[(input_tensor, output_tensor), (input_graph, output_graph)],
|
||||
)
|
||||
@pytest.mark.parametrize("automatic_batching", [True, False])
|
||||
def test_dataloader(input_, output_, automatic_batching):
|
||||
@pytest.mark.parametrize("batch_size", [None, 10])
|
||||
@pytest.mark.parametrize("batching_mode", ["common_batch_size", "propotional"])
|
||||
def test_dataloader(
|
||||
input_, output_, automatic_batching, batch_size, batching_mode
|
||||
):
|
||||
problem = SupervisedProblem(input_=input_, output_=output_)
|
||||
solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
|
||||
trainer = Trainer(
|
||||
@@ -169,7 +173,7 @@ def test_dataloader(input_, output_, automatic_batching):
|
||||
val_size=0.3,
|
||||
test_size=0.0,
|
||||
automatic_batching=automatic_batching,
|
||||
common_batch_size=True,
|
||||
batching_mode=batching_mode,
|
||||
)
|
||||
dm = trainer.data_module
|
||||
dm.setup()
|
||||
@@ -187,7 +191,7 @@ def test_dataloader(input_, output_, automatic_batching):
|
||||
|
||||
dataloader = dm.val_dataloader()
|
||||
assert isinstance(dataloader, PinaDataLoader)
|
||||
assert len(dataloader) == 3
|
||||
assert len(dataloader) == 3 if batch_size is not None else 1
|
||||
data = next(iter(dataloader))
|
||||
assert isinstance(data, dict)
|
||||
if isinstance(input_, list):
|
||||
@@ -225,7 +229,7 @@ def test_dataloader_labels(input_, output_, automatic_batching):
|
||||
val_size=0.3,
|
||||
test_size=0.0,
|
||||
automatic_batching=automatic_batching,
|
||||
common_batch_size=True,
|
||||
# common_batch_size=True,
|
||||
)
|
||||
dm = trainer.data_module
|
||||
dm.setup()
|
||||
|
||||
@@ -117,6 +117,10 @@ def test_solver_train(use_lt, batch_size, compile):
|
||||
assert isinstance(solver.model, OptimizedModule)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_solver_train(use_lt=True, batch_size=20, compile=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
|
||||
@pytest.mark.parametrize("use_lt", [True, False])
|
||||
def test_solver_train_graph(batch_size, use_lt):
|
||||
|
||||
Reference in New Issue
Block a user