R3Refinment double precision training fix (#277)

* r3 ref double precision fix
* fix label missing


---------

Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.Home>
This commit is contained in:
Dario Coscia
2024-04-04 14:41:49 +02:00
committed by GitHub
parent 56d5f3627b
commit 5c50906771
2 changed files with 41 additions and 13 deletions

View File

@@ -73,3 +73,14 @@ def test_r3refinment_routine():
callbacks=[R3Refinement(sample_every=1)],
max_epochs=5)
trainer.train()
def test_r3refinment_routine_double_precision():
model = FeedForward(len(poisson_problem.input_variables),
len(poisson_problem.output_variables))
solver = PINN(problem=poisson_problem, model=model)
trainer = Trainer(solver=solver,
precision='64-true',
accelerator='cpu',
callbacks=[R3Refinement(sample_every=2)],
max_epochs=5)
trainer.train()