Dev Update (#582)
* Fix adaptive refinement (#571) --------- Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com> * Remove collector * Fixes * Fixes * rm unnecessary comment * fix advection (#581) * Fix tutorial .html link (#580) * fix problem data collection for v0.1 (#584) * Message Passing Module (#516) * add deep tensor network block * add interaction network block * add radial field network block * add schnet block * add equivariant network block * fix + tests + doc files * fix egnn + equivariance/invariance tests Co-authored-by: Dario Coscia <dariocos99@gmail.com> --------- Co-authored-by: giovanni <giovanni.canali98@yahoo.it> Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it> * add type checker (#527) --------- Co-authored-by: Filippo Olivo <filippo@filippoolivo.com> Co-authored-by: Giovanni Canali <115086358+GiovanniCanali@users.noreply.github.com> Co-authored-by: giovanni <giovanni.canali98@yahoo.it> Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it>
This commit is contained in:
@@ -1,45 +1,58 @@
|
||||
import pytest
|
||||
|
||||
from torch.nn import MSELoss
|
||||
|
||||
from pina.solver import PINN
|
||||
from pina.trainer import Trainer
|
||||
from pina.model import FeedForward
|
||||
from pina.problem.zoo import Poisson2DSquareProblem as Poisson
|
||||
from pina.callback import R3Refinement
|
||||
from pina.callback.refinement import R3Refinement
|
||||
|
||||
|
||||
# make the problem
|
||||
poisson_problem = Poisson()
|
||||
boundaries = ["g1", "g2", "g3", "g4"]
|
||||
n = 10
|
||||
poisson_problem.discretise_domain(n, "grid", domains=boundaries)
|
||||
poisson_problem.discretise_domain(n, "grid", domains="D")
|
||||
poisson_problem.discretise_domain(10, "grid", domains=["g1", "g2", "g3", "g4"])
|
||||
poisson_problem.discretise_domain(10, "grid", domains="D")
|
||||
model = FeedForward(
|
||||
len(poisson_problem.input_variables), len(poisson_problem.output_variables)
|
||||
)
|
||||
|
||||
# make the solver
|
||||
solver = PINN(problem=poisson_problem, model=model)
|
||||
|
||||
|
||||
# def test_r3constructor():
|
||||
# R3Refinement(sample_every=10)
|
||||
def test_constructor():
|
||||
# good constructor
|
||||
R3Refinement(sample_every=10)
|
||||
R3Refinement(sample_every=10, residual_loss=MSELoss)
|
||||
R3Refinement(sample_every=10, condition_to_update=["D"])
|
||||
# wrong constructor
|
||||
with pytest.raises(ValueError):
|
||||
R3Refinement(sample_every="str")
|
||||
with pytest.raises(ValueError):
|
||||
R3Refinement(sample_every=10, condition_to_update=3)
|
||||
|
||||
|
||||
# def test_r3refinment_routine():
|
||||
# # make the trainer
|
||||
# trainer = Trainer(solver=solver,
|
||||
# callback=[R3Refinement(sample_every=1)],
|
||||
# accelerator='cpu',
|
||||
# max_epochs=5)
|
||||
# trainer.train()
|
||||
|
||||
# def test_r3refinment_routine():
|
||||
# model = FeedForward(len(poisson_problem.input_variables),
|
||||
# len(poisson_problem.output_variables))
|
||||
# solver = PINN(problem=poisson_problem, model=model)
|
||||
# trainer = Trainer(solver=solver,
|
||||
# callback=[R3Refinement(sample_every=1)],
|
||||
# accelerator='cpu',
|
||||
# max_epochs=5)
|
||||
# before_n_points = {loc : len(pts) for loc, pts in trainer.solver.problem.input_pts.items()}
|
||||
# trainer.train()
|
||||
# after_n_points = {loc : len(pts) for loc, pts in trainer.solver.problem.input_pts.items()}
|
||||
# assert before_n_points == after_n_points
|
||||
@pytest.mark.parametrize(
|
||||
"condition_to_update", [["D", "g1"], ["D", "g1", "g2", "g3", "g4"]]
|
||||
)
|
||||
def test_sample(condition_to_update):
|
||||
trainer = Trainer(
|
||||
solver=solver,
|
||||
callbacks=[
|
||||
R3Refinement(
|
||||
sample_every=1, condition_to_update=condition_to_update
|
||||
)
|
||||
],
|
||||
accelerator="cpu",
|
||||
max_epochs=5,
|
||||
)
|
||||
before_n_points = {
|
||||
loc: len(trainer.solver.problem.input_pts[loc])
|
||||
for loc in condition_to_update
|
||||
}
|
||||
trainer.train()
|
||||
after_n_points = {
|
||||
loc: len(trainer.data_module.train_dataset.input[loc])
|
||||
for loc in condition_to_update
|
||||
}
|
||||
assert before_n_points == trainer.callbacks[0].initial_population_size
|
||||
assert before_n_points == after_n_points
|
||||
|
||||
Reference in New Issue
Block a user