adaptive refinement callback (#299)

Fixed problem of non-constant number of points
This commit is contained in:
Michele Alessi
2024-05-21 09:51:13 +02:00
committed by GitHub
parent a72ce67873
commit 5f89968805
2 changed files with 24 additions and 30 deletions

View File

@@ -75,13 +75,15 @@ def test_r3refinment_routine():
max_epochs=5)
trainer.train()
def test_r3refinment_routine_double_precision():
def test_r3refinment_routine():
model = FeedForward(len(poisson_problem.input_variables),
len(poisson_problem.output_variables))
solver = PINN(problem=poisson_problem, model=model)
trainer = Trainer(solver=solver,
precision='64-true',
callbacks=[R3Refinement(sample_every=1)],
accelerator='cpu',
callbacks=[R3Refinement(sample_every=2)],
max_epochs=5)
before_n_points = {loc : len(pts) for loc, pts in trainer.solver.problem.input_pts.items()}
trainer.train()
after_n_points = {loc : len(pts) for loc, pts in trainer.solver.problem.input_pts.items()}
assert before_n_points == after_n_points