From 5c509067719b8f3e0a220ddbeae384e99aabd275 Mon Sep 17 00:00:00 2001 From: Dario Coscia <93731561+dario-coscia@users.noreply.github.com> Date: Thu, 4 Apr 2024 14:41:49 +0200 Subject: [PATCH] R3Refinment double precision training fix (#277) * r3 ref double precision fix * fix label missing --------- Co-authored-by: Dario Coscia Co-authored-by: Dario Coscia --- .../callbacks/adaptive_refinment_callbacks.py | 43 +++++++++++++------ .../test_adaptive_refinment_callbacks.py | 11 +++++ 2 files changed, 41 insertions(+), 13 deletions(-) diff --git a/pina/callbacks/adaptive_refinment_callbacks.py b/pina/callbacks/adaptive_refinment_callbacks.py index b5e2b70..2a6fc19 100644 --- a/pina/callbacks/adaptive_refinment_callbacks.py +++ b/pina/callbacks/adaptive_refinment_callbacks.py @@ -1,8 +1,8 @@ """PINA Callbacks Implementations""" -# from lightning.pytorch.callbacks import Callback -from pytorch_lightning.callbacks import Callback import torch +from pytorch_lightning.callbacks import Callback +from ..label_tensor import LabelTensor from ..utils import check_consistency @@ -12,19 +12,22 @@ class R3Refinement(Callback): """ PINA Implementation of an R3 Refinement Callback. - This callback implements the R3 (Retain-Resample-Release) routine for sampling new points based on adaptive search. - The algorithm incrementally accumulates collocation points in regions of high PDE residuals, and releases those - with low residuals. Points are sampled uniformly in all regions where sampling is needed. + This callback implements the R3 (Retain-Resample-Release) routine for + sampling new points based on adaptive search. + The algorithm incrementally accumulates collocation points in regions + of high PDE residuals, and releases those + with low residuals. Points are sampled uniformly in all regions + where sampling is needed. .. seealso:: - Original Reference: Daw, Arka, et al. *Mitigating Propagation Failures in Physics-informed Neural Networks + Original Reference: Daw, Arka, et al. *Mitigating Propagation + Failures in Physics-informed Neural Networks using Retain-Resample-Release (R3) Sampling. (2023)*. DOI: `10.48550/arXiv.2207.02338 `_ :param int sample_every: Frequency for sampling. - :raises ValueError: If `sample_every` is not an integer. Example: @@ -47,6 +50,17 @@ class R3Refinement(Callback): # extract the solver and device from trainer solver = trainer._model device = trainer._accelerator_connector._accelerator_flag + precision = trainer.precision + if precision == "64-true": + precision = torch.float64 + elif precision == "32-true": + precision = torch.float32 + else: + raise RuntimeError("Currently R3Refinement is only implemented " + "for precision '32-true' and '64-true', set " + "Trainer precision to match one of the " + "available precisions.") + # compute residual res_loss = {} @@ -55,10 +69,10 @@ class R3Refinement(Callback): condition = solver.problem.conditions[location] pts = solver.problem.input_pts[location] # send points to correct device - pts = pts.to(device) + pts = pts.to(device=device, dtype=precision) pts = pts.requires_grad_(True) pts.retain_grad() - # PINN loss: equation evaluated only on locations where sampling is needed + # PINN loss: equation evaluated only for sampling locations target = condition.equation.residual(pts, solver.forward(pts)) res_loss[location] = torch.abs(target).as_subclass(torch.Tensor) tot_loss.append(torch.abs(target)) @@ -86,13 +100,14 @@ class R3Refinement(Callback): for location in self._sampling_locations: pts = trainer._model.problem.input_pts[location] labels = pts.labels - pts = pts.cpu().detach() + pts = pts.cpu().detach().as_subclass(torch.Tensor) residuals = res_loss[location].cpu() mask = (residuals > avg).flatten() if any( mask ): # if there are residuals greater than averge we append them - pts = pts[mask] # TODO masking remove labels + # Fix the issue, masking remove labels + pts = (pts[mask]).as_subclass(LabelTensor) pts.labels = labels old_pts[location] = pts tot_points += len(pts) @@ -122,7 +137,8 @@ class R3Refinement(Callback): """ Callback function called at the start of training. - This method extracts the locations for sampling from the problem conditions and calculates the total population. + This method extracts the locations for sampling from the problem + conditions and calculates the total population. :param trainer: The trainer object managing the training process. :type trainer: pytorch_lightning.Trainer @@ -151,7 +167,8 @@ class R3Refinement(Callback): """ Callback function called at the end of each training epoch. - This method triggers the R3 routine for refinement if the current epoch is a multiple of `_sample_every`. + This method triggers the R3 routine for refinement if the current + epoch is a multiple of `_sample_every`. :param trainer: The trainer object managing the training process. :type trainer: pytorch_lightning.Trainer diff --git a/tests/test_callbacks/test_adaptive_refinment_callbacks.py b/tests/test_callbacks/test_adaptive_refinment_callbacks.py index 8964e46..fb74367 100644 --- a/tests/test_callbacks/test_adaptive_refinment_callbacks.py +++ b/tests/test_callbacks/test_adaptive_refinment_callbacks.py @@ -73,3 +73,14 @@ def test_r3refinment_routine(): callbacks=[R3Refinement(sample_every=1)], max_epochs=5) trainer.train() + +def test_r3refinment_routine_double_precision(): + model = FeedForward(len(poisson_problem.input_variables), + len(poisson_problem.output_variables)) + solver = PINN(problem=poisson_problem, model=model) + trainer = Trainer(solver=solver, + precision='64-true', + accelerator='cpu', + callbacks=[R3Refinement(sample_every=2)], + max_epochs=5) + trainer.train()