Files
PINA/tests/test_callback/test_adaptive_refinement_callback.py
FilippoOlivo 8440a672a7 fix tests
2025-11-13 17:03:31 +01:00

59 lines
1.7 KiB
Python

import pytest
from torch.nn import MSELoss
from pina.solver import PINN
from pina.trainer import Trainer
from pina.model import FeedForward
from pina.problem.zoo import Poisson2DSquareProblem as Poisson
from pina.callback.refinement import R3Refinement
# Module-level fixtures shared by the tests in this file.
# Build the Poisson problem and discretise its domains before any solver
# is created from it.
poisson_problem = Poisson()
# Discretise the domains "g1".."g4" with a "grid" strategy and the value 10
# (presumably the number of points per domain; presumably these are the
# boundary domains -- confirm against Poisson2DSquareProblem).
poisson_problem.discretise_domain(10, "grid", domains=["g1", "g2", "g3", "g4"])
# Same discretisation for the domain "D" (presumably the interior -- confirm).
poisson_problem.discretise_domain(10, "grid", domains="D")
# Network whose two positional sizes are the number of input and output
# variables declared by the problem.
model = FeedForward(
len(poisson_problem.input_variables), len(poisson_problem.output_variables)
)
# Solver reused by test_sample below; it is a single shared instance, so it
# is not rebuilt between parametrized runs.
solver = PINN(problem=poisson_problem, model=model)
def test_constructor():
    """Check which argument combinations ``R3Refinement`` accepts.

    Valid keyword sets must construct without raising; invalid ones must
    raise ``ValueError``.
    """
    # Keyword sets that must be accepted by the constructor.
    accepted = (
        {"sample_every": 10},
        {"sample_every": 10, "residual_loss": MSELoss},
        {"sample_every": 10, "condition_to_update": ["D"]},
    )
    for kwargs in accepted:
        R3Refinement(**kwargs)

    # Keyword sets with a wrongly typed argument: each must be rejected.
    rejected = (
        {"sample_every": "str"},
        {"sample_every": 10, "condition_to_update": 3},
    )
    for kwargs in rejected:
        with pytest.raises(ValueError):
            R3Refinement(**kwargs)
@pytest.mark.parametrize(
    "condition_to_update", [["D", "g1"], ["D", "g1", "g2", "g3", "g4"]]
)
def test_sample(condition_to_update):
    """Train with ``R3Refinement`` and check the point counts are preserved.

    For every condition listed in ``condition_to_update`` the number of
    points counted before training must equal both the callback's recorded
    ``initial_population_size`` and the number of points found in the
    training dataset after training.
    """
    refinement = R3Refinement(
        sample_every=1, condition_to_update=condition_to_update
    )
    trainer = Trainer(
        solver=solver,
        callbacks=[refinement],
        accelerator="cpu",
        max_epochs=5,
    )

    # Count the points of each refined condition before any training step.
    before_n_points = {}
    for name in condition_to_update:
        before_n_points[name] = len(trainer.solver.problem.input_pts[name])

    trainer.train()

    # Count the points again, this time from the training dataset.
    after_n_points = {}
    for name in condition_to_update:
        after_n_points[name] = len(
            trainer.data_module.train_dataset[name].input
        )

    # The callback is looked up through the trainer, as the first callback.
    assert before_n_points == trainer.callbacks[0].initial_population_size
    assert before_n_points == after_n_points