fix tests

This commit is contained in:
FilippoOlivo
2025-11-13 17:03:31 +01:00
parent 0ee63686dd
commit 8440a672a7
5 changed files with 289 additions and 300 deletions

View File

@@ -142,14 +142,10 @@ def test_setup(solver, fn, stage, apply_to):
for cond in ["data1", "data2"]:
scale = scale_fn(
trainer_copy.data_module.train_dataset.conditions_dict[cond][
apply_to
]
trainer_copy.data_module.train_dataset[cond].data[apply_to]
)
shift = shift_fn(
trainer_copy.data_module.train_dataset.conditions_dict[cond][
apply_to
]
trainer_copy.data_module.train_dataset[cond].data[apply_to]
)
assert "scale" in normalizer[cond]
assert "shift" in normalizer[cond]
@@ -158,8 +154,8 @@ def test_setup(solver, fn, stage, apply_to):
for ds_name in stage_map[stage]:
dataset = getattr(trainer.data_module, ds_name, None)
old_dataset = getattr(trainer_copy.data_module, ds_name, None)
current_points = dataset.conditions_dict[cond][apply_to]
old_points = old_dataset.conditions_dict[cond][apply_to]
current_points = dataset[cond].data[apply_to]
old_points = old_dataset[cond].data[apply_to]
expected = (old_points - shift) / scale
assert torch.allclose(current_points, expected)
@@ -204,10 +200,10 @@ def test_setup_pinn(fn, stage, apply_to):
cond = "data"
scale = scale_fn(
trainer_copy.data_module.train_dataset.conditions_dict[cond][apply_to]
trainer_copy.data_module.train_dataset[cond].data[apply_to]
)
shift = shift_fn(
trainer_copy.data_module.train_dataset.conditions_dict[cond][apply_to]
trainer_copy.data_module.train_dataset[cond].data[apply_to]
)
assert "scale" in normalizer[cond]
assert "shift" in normalizer[cond]
@@ -216,8 +212,8 @@ def test_setup_pinn(fn, stage, apply_to):
for ds_name in stage_map[stage]:
dataset = getattr(trainer.data_module, ds_name, None)
old_dataset = getattr(trainer_copy.data_module, ds_name, None)
current_points = dataset.conditions_dict[cond][apply_to]
old_points = old_dataset.conditions_dict[cond][apply_to]
current_points = dataset[cond].data[apply_to]
old_points = old_dataset[cond].data[apply_to]
expected = (old_points - shift) / scale
assert torch.allclose(current_points, expected)
@@ -242,3 +238,7 @@ def test_setup_graph_dataset():
)
with pytest.raises(NotImplementedError):
trainer.train()
# if __name__ == "__main__":
# test_setup(supervised_solver_lt, [torch.std, torch.mean], "all", "input")