* Adding a test for all PINN solvers to assert that the metrics are correctly log
* Adding test for Metric Tracker * Modify Metric Tracker to correctly log metrics
This commit is contained in:
committed by
Nicola Demo
parent
d00fb95d6e
commit
0fa4e1e58a
@@ -163,6 +163,17 @@ def test_train_cpu():
|
||||
accelerator='cpu', batch_size=20)
|
||||
trainer.train()
|
||||
|
||||
def test_log():
|
||||
poisson_problem.discretise_domain(100)
|
||||
solver = PINN(problem = poisson_problem, model=model, loss=LpLoss())
|
||||
trainer = Trainer(solver, max_epochs=2, accelerator='cpu')
|
||||
trainer.train()
|
||||
# assert the logged metrics are correct
|
||||
logged_metrics = sorted(list(trainer.logged_metrics.keys()))
|
||||
total_metrics = sorted(
|
||||
list([key + '_loss' for key in poisson_problem.conditions.keys()])
|
||||
+ ['mean_loss'])
|
||||
assert logged_metrics == total_metrics
|
||||
|
||||
def test_train_restore():
|
||||
tmpdir = "tests/tmp_restore"
|
||||
|
||||
Reference in New Issue
Block a user