Files
PINA/tests/test_callback/test_adaptive_refinement_callback.py
Dario Coscia 7bf7d34d0f Dev Update (#582)
* Fix adaptive refinement (#571)


---------

Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com>

* Remove collector

* Fixes

* Fixes

* rm unnecessary comment

* fix advection (#581)

* Fix tutorial .html link (#580)

* fix problem data collection for v0.1 (#584)

* Message Passing Module (#516)

* add deep tensor network block

* add interaction network block

* add radial field network block

* add schnet block

* add equivariant network block

* fix + tests + doc files

* fix egnn + equivariance/invariance tests

Co-authored-by: Dario Coscia <dariocos99@gmail.com>

---------

Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it>

* add type checker (#527)

---------

Co-authored-by: Filippo Olivo <filippo@filippoolivo.com>
Co-authored-by: Giovanni Canali <115086358+GiovanniCanali@users.noreply.github.com>
Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it>
2025-06-13 17:34:37 +02:00

59 lines
1.7 KiB
Python

import pytest
from torch.nn import MSELoss
from pina.solver import PINN
from pina.trainer import Trainer
from pina.model import FeedForward
from pina.problem.zoo import Poisson2DSquareProblem as Poisson
from pina.callback.refinement import R3Refinement
# Module-level objects shared by the tests below: the problem, the network
# and the solver wrapping both.
poisson_problem = Poisson()

# Discretise domains "g1".."g4" first, then "D", with the same (10, "grid")
# arguments for each call.
poisson_problem.discretise_domain(10, "grid", domains=["g1", "g2", "g3", "g4"])
poisson_problem.discretise_domain(10, "grid", domains="D")

# The network is sized from the number of input/output variables that the
# problem declares.
model = FeedForward(
    len(poisson_problem.input_variables),
    len(poisson_problem.output_variables),
)
solver = PINN(model=model, problem=poisson_problem)
def test_constructor():
    """Check which keyword arguments ``R3Refinement`` accepts.

    The accepted argument sets must build a callback without raising; a
    string ``sample_every`` and an integer ``condition_to_update`` must each
    raise ``ValueError``.
    """
    # Argument sets that must construct cleanly. ``residual_loss`` is given
    # the ``MSELoss`` class itself, not an instance.
    accepted = (
        {"sample_every": 10},
        {"sample_every": 10, "residual_loss": MSELoss},
        {"sample_every": 10, "condition_to_update": ["D"]},
    )
    for kwargs in accepted:
        R3Refinement(**kwargs)

    # Argument sets that must be rejected with ValueError.
    rejected = (
        {"sample_every": "str"},
        {"sample_every": 10, "condition_to_update": 3},
    )
    for kwargs in rejected:
        with pytest.raises(ValueError):
            R3Refinement(**kwargs)
@pytest.mark.parametrize(
    "condition_to_update", [["D", "g1"], ["D", "g1", "g2", "g3", "g4"]]
)
def test_sample(condition_to_update):
    """Check that training with ``R3Refinement`` keeps point counts fixed.

    A trainer is run on CPU for 5 epochs with the callback configured with
    ``sample_every=1``. For every name in ``condition_to_update`` the number
    of points before training must equal both the callback's recorded
    ``initial_population_size`` and the number of points in the training
    dataset after training.

    :param condition_to_update: Names of the conditions handed to the
        callback and checked by the assertions.
    """
    refinement = R3Refinement(
        sample_every=1, condition_to_update=condition_to_update
    )
    trainer = Trainer(
        solver=solver,
        callbacks=[refinement],
        accelerator="cpu",
        max_epochs=5,
    )

    # Point counts taken from the problem before any training step runs.
    problem_points = trainer.solver.problem.input_pts
    before_n_points = {
        name: len(problem_points[name]) for name in condition_to_update
    }

    trainer.train()

    # Point counts taken from the training dataset once training is done.
    trained_inputs = trainer.data_module.train_dataset.input
    after_n_points = {
        name: len(trained_inputs[name]) for name in condition_to_update
    }

    # The callback is read back through the trainer's callback list.
    assert before_n_points == trainer.callbacks[0].initial_population_size
    assert before_n_points == after_n_points