* Adding a test for all PINN solvers to assert that the metrics are correctly log

* Adding test for Metric Tracker
* Modify Metric Tracker to correctly log metrics
This commit is contained in:
dario-coscia
2024-08-06 11:12:19 +02:00
committed by Nicola Demo
parent d00fb95d6e
commit 0fa4e1e58a
10 changed files with 308 additions and 9 deletions

View File

@@ -163,6 +163,17 @@ def test_train_cpu():
accelerator='cpu', batch_size=20)
trainer.train()
def test_log():
poisson_problem.discretise_domain(100)
solver = PINN(problem = poisson_problem, model=model, loss=LpLoss())
trainer = Trainer(solver, max_epochs=2, accelerator='cpu')
trainer.train()
# assert the logged metrics are correct
logged_metrics = sorted(list(trainer.logged_metrics.keys()))
total_metrics = sorted(
list([key + '_loss' for key in poisson_problem.conditions.keys()])
+ ['mean_loss'])
assert logged_metrics == total_metrics
def test_train_restore():
tmpdir = "tests/tmp_restore"