59 lines
1.7 KiB
Python
59 lines
1.7 KiB
Python
import pytest
|
|
|
|
from torch.nn import MSELoss
|
|
|
|
from pina.solver import PINN
|
|
from pina.trainer import Trainer
|
|
from pina.model import FeedForward
|
|
from pina.problem.zoo import Poisson2DSquareProblem as Poisson
|
|
from pina.callback.refinement import R3Refinement
|
|
|
|
|
|
# --- shared problem / solver setup ------------------------------------------
# Built once at import time; the training test below reuses ``solver``.
poisson_problem = Poisson()
# Grid-sample the "g1".."g4" domains (presumably the boundaries -- verify
# against Poisson2DSquareProblem), then domain "D", 10 points each.
poisson_problem.discretise_domain(10, "grid", domains=["g1", "g2", "g3", "g4"])
poisson_problem.discretise_domain(10, "grid", domains="D")

# The network's input/output widths follow the problem's declared variables.
_n_inputs = len(poisson_problem.input_variables)
_n_outputs = len(poisson_problem.output_variables)
model = FeedForward(_n_inputs, _n_outputs)

solver = PINN(problem=poisson_problem, model=model)
|
|
|
|
|
|
def test_constructor():
    """Check which ``R3Refinement`` argument combinations are accepted.

    Valid configurations must construct without raising; invalid ones
    (a non-integer ``sample_every`` or a non-list ``condition_to_update``)
    must raise ``ValueError``.
    """
    accepted = (
        {"sample_every": 10},
        {"sample_every": 10, "residual_loss": MSELoss},
        {"sample_every": 10, "condition_to_update": ["D"]},
    )
    for kwargs in accepted:
        R3Refinement(**kwargs)

    rejected = (
        {"sample_every": "str"},
        {"sample_every": 10, "condition_to_update": 3},
    )
    for kwargs in rejected:
        # Each bad configuration gets its own ``raises`` context so one
        # failing case cannot mask another.
        with pytest.raises(ValueError):
            R3Refinement(**kwargs)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "condition_to_update", [["D", "g1"], ["D", "g1", "g2", "g3", "g4"]]
)
def test_sample(condition_to_update):
    """Training with ``R3Refinement`` keeps each refined condition's size.

    Trains the shared PINN solver for 5 epochs with an ``R3Refinement``
    callback (``sample_every=1``) restricted to ``condition_to_update``.
    It then checks two things for every listed condition. First, the point
    count recorded before training equals the callback's
    ``initial_population_size``. Second, it equals the size of the training
    dataset after training.
    """
    # Keep a direct handle on the callback rather than reading it back via
    # ``trainer.callbacks[0]``. Lightning adds its own default callbacks
    # (progress bar, checkpointing, ...) to that list and may reorder it,
    # so positional indexing is fragile.
    refinement = R3Refinement(
        sample_every=1, condition_to_update=condition_to_update
    )
    # NOTE(review): ``solver`` is module-level and shared by both
    # parametrized cases, so the second case starts from the state the
    # first one left behind. Consider building a fresh solver per case.
    trainer = Trainer(
        solver=solver,
        callbacks=[refinement],
        accelerator="cpu",
        max_epochs=5,
    )

    before_n_points = {
        loc: len(trainer.solver.problem.input_pts[loc])
        for loc in condition_to_update
    }
    trainer.train()
    after_n_points = {
        loc: len(trainer.data_module.train_dataset[loc].input)
        for loc in condition_to_update
    }

    assert before_n_points == refinement.initial_population_size
    assert before_n_points == after_n_points
|