* Fix adaptive refinement (#571) --------- Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com> * Remove collector * Fixes * Fixes * rm unnecessary comment * fix advection (#581) * Fix tutorial .html link (#580) * fix problem data collection for v0.1 (#584) * Message Passing Module (#516) * add deep tensor network block * add interaction network block * add radial field network block * add schnet block * add equivariant network block * fix + tests + doc files * fix egnn + equivariance/invariance tests Co-authored-by: Dario Coscia <dariocos99@gmail.com> --------- Co-authored-by: giovanni <giovanni.canali98@yahoo.it> Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it> * add type checker (#527) --------- Co-authored-by: Filippo Olivo <filippo@filippoolivo.com> Co-authored-by: Giovanni Canali <115086358+GiovanniCanali@users.noreply.github.com> Co-authored-by: giovanni <giovanni.canali98@yahoo.it> Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it>
59 lines
1.7 KiB
Python
59 lines
1.7 KiB
Python
import pytest
|
|
|
|
from torch.nn import MSELoss
|
|
|
|
from pina.solver import PINN
|
|
from pina.trainer import Trainer
|
|
from pina.model import FeedForward
|
|
from pina.problem.zoo import Poisson2DSquareProblem as Poisson
|
|
from pina.callback.refinement import R3Refinement
|
|
|
|
|
|
# Shared fixtures for the tests below. Build a Poisson problem and
# discretise it with n=10 in "grid" mode:
# - first the boundary domains "g1".."g4" together,
# - then the interior domain "D".
# Then wrap a feed-forward network, sized from the problem's input and output
# variables, in a PINN solver.
poisson_problem = Poisson()
for _domains in (["g1", "g2", "g3", "g4"], "D"):
    poisson_problem.discretise_domain(10, "grid", domains=_domains)

model = FeedForward(
    len(poisson_problem.input_variables),
    len(poisson_problem.output_variables),
)
solver = PINN(model=model, problem=poisson_problem)
|
|
|
|
|
|
def test_constructor():
    """Check R3Refinement construction with valid and invalid arguments."""
    # Each of these keyword sets must construct without raising.
    accepted = (
        {"sample_every": 10},
        {"sample_every": 10, "residual_loss": MSELoss},
        {"sample_every": 10, "condition_to_update": ["D"]},
    )
    for kwargs in accepted:
        R3Refinement(**kwargs)

    # A string sampling period and an integer condition_to_update must both
    # be rejected with ValueError.
    rejected = (
        {"sample_every": "str"},
        {"sample_every": 10, "condition_to_update": 3},
    )
    for kwargs in rejected:
        with pytest.raises(ValueError):
            R3Refinement(**kwargs)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "condition_to_update", [["D", "g1"], ["D", "g1", "g2", "g3", "g4"]]
)
def test_sample(condition_to_update):
    """Train with R3 refinement and check that per-condition sizes are kept.

    The callback resamples every epoch (``sample_every=1``) on the selected
    conditions. After training, both of these must equal the per-condition
    point counts the problem had before training:
    - the population sizes recorded by the callback,
    - the sizes of the training dataset.

    :param condition_to_update: Names of the conditions to refine.
    """
    # Keep a direct handle on the callback rather than reading it back as
    # ``trainer.callbacks[0]``. Lightning's Trainer appends default callbacks
    # and may reorder the list, so a user callback's index is not guaranteed.
    refinement = R3Refinement(
        sample_every=1, condition_to_update=condition_to_update
    )
    trainer = Trainer(
        solver=solver,
        callbacks=[refinement],
        accelerator="cpu",
        max_epochs=5,
    )

    # Point counts per refined condition before any refinement happens.
    before_n_points = {
        loc: len(trainer.solver.problem.input_pts[loc])
        for loc in condition_to_update
    }
    trainer.train()

    # Point counts in the training dataset after refinement has run.
    after_n_points = {
        loc: len(trainer.data_module.train_dataset.input[loc])
        for loc in condition_to_update
    }

    assert before_n_points == refinement.initial_population_size
    assert before_n_points == after_n_points
|