Dev Update (#582)

* Fix adaptive refinement (#571)


---------

Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com>

* Remove collector

* Fixes

* Fixes

* rm unnecessary comment

* fix advection (#581)

* Fix tutorial .html link (#580)

* fix problem data collection for v0.1 (#584)

* Message Passing Module (#516)

* add deep tensor network block

* add interaction network block

* add radial field network block

* add schnet block

* add equivariant network block

* fix + tests + doc files

* fix egnn + equivariance/invariance tests

Co-authored-by: Dario Coscia <dariocos99@gmail.com>

---------

Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it>

* add type checker (#527)

---------

Co-authored-by: Filippo Olivo <filippo@filippoolivo.com>
Co-authored-by: Giovanni Canali <115086358+GiovanniCanali@users.noreply.github.com>
Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it>
This commit is contained in:
Dario Coscia
2025-06-13 17:34:37 +02:00
committed by GitHub
parent 6b355b45de
commit 7bf7d34d0f
40 changed files with 1963 additions and 581 deletions

View File

@@ -1,45 +1,58 @@
import pytest
from torch.nn import MSELoss
from pina.solver import PINN
from pina.trainer import Trainer
from pina.model import FeedForward
from pina.problem.zoo import Poisson2DSquareProblem as Poisson
from pina.callback import R3Refinement
from pina.callback.refinement import R3Refinement
# make the problem
poisson_problem = Poisson()
boundaries = ["g1", "g2", "g3", "g4"]
n = 10
poisson_problem.discretise_domain(n, "grid", domains=boundaries)
poisson_problem.discretise_domain(n, "grid", domains="D")
poisson_problem.discretise_domain(10, "grid", domains=["g1", "g2", "g3", "g4"])
poisson_problem.discretise_domain(10, "grid", domains="D")
model = FeedForward(
len(poisson_problem.input_variables), len(poisson_problem.output_variables)
)
# make the solver
solver = PINN(problem=poisson_problem, model=model)
# def test_r3constructor():
# R3Refinement(sample_every=10)
def test_constructor():
# good constructor
R3Refinement(sample_every=10)
R3Refinement(sample_every=10, residual_loss=MSELoss)
R3Refinement(sample_every=10, condition_to_update=["D"])
# wrong constructor
with pytest.raises(ValueError):
R3Refinement(sample_every="str")
with pytest.raises(ValueError):
R3Refinement(sample_every=10, condition_to_update=3)
# def test_r3refinment_routine():
# # make the trainer
# trainer = Trainer(solver=solver,
# callback=[R3Refinement(sample_every=1)],
# accelerator='cpu',
# max_epochs=5)
# trainer.train()
# def test_r3refinment_routine():
# model = FeedForward(len(poisson_problem.input_variables),
# len(poisson_problem.output_variables))
# solver = PINN(problem=poisson_problem, model=model)
# trainer = Trainer(solver=solver,
# callback=[R3Refinement(sample_every=1)],
# accelerator='cpu',
# max_epochs=5)
# before_n_points = {loc : len(pts) for loc, pts in trainer.solver.problem.input_pts.items()}
# trainer.train()
# after_n_points = {loc : len(pts) for loc, pts in trainer.solver.problem.input_pts.items()}
# assert before_n_points == after_n_points
@pytest.mark.parametrize(
"condition_to_update", [["D", "g1"], ["D", "g1", "g2", "g3", "g4"]]
)
def test_sample(condition_to_update):
trainer = Trainer(
solver=solver,
callbacks=[
R3Refinement(
sample_every=1, condition_to_update=condition_to_update
)
],
accelerator="cpu",
max_epochs=5,
)
before_n_points = {
loc: len(trainer.solver.problem.input_pts[loc])
for loc in condition_to_update
}
trainer.train()
after_n_points = {
loc: len(trainer.data_module.train_dataset.input[loc])
for loc in condition_to_update
}
assert before_n_points == trainer.callbacks[0].initial_population_size
assert before_n_points == after_n_points