Files
PINA/tests/test_callback/test_metric_tracker.py
Dario Coscia 632934f9cc Update callbacks and tests (#482)
---------

Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
2025-03-19 17:48:25 +01:00

41 lines
1.1 KiB
Python

from pina.solver import PINN
from pina.trainer import Trainer
from pina.model import FeedForward
from pina.callback import MetricTracker
from pina.problem.zoo import Poisson2DSquareProblem as Poisson
# make the problem
poisson_problem = Poisson()
boundaries = ["g1", "g2", "g3", "g4"]
n = 10
poisson_problem.discretise_domain(n, "grid", domains=boundaries)
poisson_problem.discretise_domain(n, "grid", domains="D")
model = FeedForward(
len(poisson_problem.input_variables), len(poisson_problem.output_variables)
)
# make the solver
solver = PINN(problem=poisson_problem, model=model)
def test_metric_tracker_constructor():
MetricTracker()
def test_metric_tracker_routine():
# make the trainer
trainer = Trainer(
solver=solver,
callbacks=[MetricTracker()],
accelerator="cpu",
max_epochs=5,
log_every_n_steps=1,
)
trainer.train()
# get the tracked metrics
metrics = trainer.callbacks[0].metrics
# assert the logged metrics are correct
logged_metrics = sorted(list(metrics.keys()))
assert logged_metrics == ["train_loss"]