fix tests
This commit is contained in:
@@ -51,7 +51,7 @@ def test_sample(condition_to_update):
|
||||
}
|
||||
trainer.train()
|
||||
after_n_points = {
|
||||
loc: len(trainer.data_module.train_dataset.input[loc])
|
||||
loc: len(trainer.data_module.train_dataset[loc].input)
|
||||
for loc in condition_to_update
|
||||
}
|
||||
assert before_n_points == trainer.callbacks[0].initial_population_size
|
||||
|
||||
@@ -142,14 +142,10 @@ def test_setup(solver, fn, stage, apply_to):
|
||||
|
||||
for cond in ["data1", "data2"]:
|
||||
scale = scale_fn(
|
||||
trainer_copy.data_module.train_dataset.conditions_dict[cond][
|
||||
apply_to
|
||||
]
|
||||
trainer_copy.data_module.train_dataset[cond].data[apply_to]
|
||||
)
|
||||
shift = shift_fn(
|
||||
trainer_copy.data_module.train_dataset.conditions_dict[cond][
|
||||
apply_to
|
||||
]
|
||||
trainer_copy.data_module.train_dataset[cond].data[apply_to]
|
||||
)
|
||||
assert "scale" in normalizer[cond]
|
||||
assert "shift" in normalizer[cond]
|
||||
@@ -158,8 +154,8 @@ def test_setup(solver, fn, stage, apply_to):
|
||||
for ds_name in stage_map[stage]:
|
||||
dataset = getattr(trainer.data_module, ds_name, None)
|
||||
old_dataset = getattr(trainer_copy.data_module, ds_name, None)
|
||||
current_points = dataset.conditions_dict[cond][apply_to]
|
||||
old_points = old_dataset.conditions_dict[cond][apply_to]
|
||||
current_points = dataset[cond].data[apply_to]
|
||||
old_points = old_dataset[cond].data[apply_to]
|
||||
expected = (old_points - shift) / scale
|
||||
assert torch.allclose(current_points, expected)
|
||||
|
||||
@@ -204,10 +200,10 @@ def test_setup_pinn(fn, stage, apply_to):
|
||||
cond = "data"
|
||||
|
||||
scale = scale_fn(
|
||||
trainer_copy.data_module.train_dataset.conditions_dict[cond][apply_to]
|
||||
trainer_copy.data_module.train_dataset[cond].data[apply_to]
|
||||
)
|
||||
shift = shift_fn(
|
||||
trainer_copy.data_module.train_dataset.conditions_dict[cond][apply_to]
|
||||
trainer_copy.data_module.train_dataset[cond].data[apply_to]
|
||||
)
|
||||
assert "scale" in normalizer[cond]
|
||||
assert "shift" in normalizer[cond]
|
||||
@@ -216,8 +212,8 @@ def test_setup_pinn(fn, stage, apply_to):
|
||||
for ds_name in stage_map[stage]:
|
||||
dataset = getattr(trainer.data_module, ds_name, None)
|
||||
old_dataset = getattr(trainer_copy.data_module, ds_name, None)
|
||||
current_points = dataset.conditions_dict[cond][apply_to]
|
||||
old_points = old_dataset.conditions_dict[cond][apply_to]
|
||||
current_points = dataset[cond].data[apply_to]
|
||||
old_points = old_dataset[cond].data[apply_to]
|
||||
expected = (old_points - shift) / scale
|
||||
assert torch.allclose(current_points, expected)
|
||||
|
||||
@@ -242,3 +238,7 @@ def test_setup_graph_dataset():
|
||||
)
|
||||
with pytest.raises(NotImplementedError):
|
||||
trainer.train()
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# test_setup(supervised_solver_lt, [torch.std, torch.mean], "all", "input")
|
||||
|
||||
Reference in New Issue
Block a user