R3Refinment double precision training fix (#277)

* r3 ref double precision fix
* fix label missing


---------

Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.Home>
This commit is contained in:
Dario Coscia
2024-04-04 14:41:49 +02:00
committed by GitHub
parent 56d5f3627b
commit 5c50906771
2 changed files with 41 additions and 13 deletions

View File

@@ -1,8 +1,8 @@
"""PINA Callbacks Implementations""" """PINA Callbacks Implementations"""
# from lightning.pytorch.callbacks import Callback
from pytorch_lightning.callbacks import Callback
import torch import torch
from pytorch_lightning.callbacks import Callback
from ..label_tensor import LabelTensor
from ..utils import check_consistency from ..utils import check_consistency
@@ -12,19 +12,22 @@ class R3Refinement(Callback):
""" """
PINA Implementation of an R3 Refinement Callback. PINA Implementation of an R3 Refinement Callback.
This callback implements the R3 (Retain-Resample-Release) routine for sampling new points based on adaptive search. This callback implements the R3 (Retain-Resample-Release) routine for
The algorithm incrementally accumulates collocation points in regions of high PDE residuals, and releases those sampling new points based on adaptive search.
with low residuals. Points are sampled uniformly in all regions where sampling is needed. The algorithm incrementally accumulates collocation points in regions
of high PDE residuals, and releases those
with low residuals. Points are sampled uniformly in all regions
where sampling is needed.
.. seealso:: .. seealso::
Original Reference: Daw, Arka, et al. *Mitigating Propagation Failures in Physics-informed Neural Networks Original Reference: Daw, Arka, et al. *Mitigating Propagation
Failures in Physics-informed Neural Networks
using Retain-Resample-Release (R3) Sampling. (2023)*. using Retain-Resample-Release (R3) Sampling. (2023)*.
DOI: `10.48550/arXiv.2207.02338 DOI: `10.48550/arXiv.2207.02338
<https://doi.org/10.48550/arXiv.2207.02338>`_ <https://doi.org/10.48550/arXiv.2207.02338>`_
:param int sample_every: Frequency for sampling. :param int sample_every: Frequency for sampling.
:raises ValueError: If `sample_every` is not an integer. :raises ValueError: If `sample_every` is not an integer.
Example: Example:
@@ -47,6 +50,17 @@ class R3Refinement(Callback):
# extract the solver and device from trainer # extract the solver and device from trainer
solver = trainer._model solver = trainer._model
device = trainer._accelerator_connector._accelerator_flag device = trainer._accelerator_connector._accelerator_flag
precision = trainer.precision
if precision == "64-true":
precision = torch.float64
elif precision == "32-true":
precision = torch.float32
else:
raise RuntimeError("Currently R3Refinement is only implemented "
"for precision '32-true' and '64-true', set "
"Trainer precision to match one of the "
"available precisions.")
# compute residual # compute residual
res_loss = {} res_loss = {}
@@ -55,10 +69,10 @@ class R3Refinement(Callback):
condition = solver.problem.conditions[location] condition = solver.problem.conditions[location]
pts = solver.problem.input_pts[location] pts = solver.problem.input_pts[location]
# send points to correct device # send points to correct device
pts = pts.to(device) pts = pts.to(device=device, dtype=precision)
pts = pts.requires_grad_(True) pts = pts.requires_grad_(True)
pts.retain_grad() pts.retain_grad()
# PINN loss: equation evaluated only on locations where sampling is needed # PINN loss: equation evaluated only for sampling locations
target = condition.equation.residual(pts, solver.forward(pts)) target = condition.equation.residual(pts, solver.forward(pts))
res_loss[location] = torch.abs(target).as_subclass(torch.Tensor) res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
tot_loss.append(torch.abs(target)) tot_loss.append(torch.abs(target))
@@ -86,13 +100,14 @@ class R3Refinement(Callback):
for location in self._sampling_locations: for location in self._sampling_locations:
pts = trainer._model.problem.input_pts[location] pts = trainer._model.problem.input_pts[location]
labels = pts.labels labels = pts.labels
pts = pts.cpu().detach() pts = pts.cpu().detach().as_subclass(torch.Tensor)
residuals = res_loss[location].cpu() residuals = res_loss[location].cpu()
mask = (residuals > avg).flatten() mask = (residuals > avg).flatten()
if any( if any(
mask mask
): # if there are residuals greater than averge we append them ): # if there are residuals greater than averge we append them
pts = pts[mask] # TODO masking remove labels # Fix the issue, masking remove labels
pts = (pts[mask]).as_subclass(LabelTensor)
pts.labels = labels pts.labels = labels
old_pts[location] = pts old_pts[location] = pts
tot_points += len(pts) tot_points += len(pts)
@@ -122,7 +137,8 @@ class R3Refinement(Callback):
""" """
Callback function called at the start of training. Callback function called at the start of training.
This method extracts the locations for sampling from the problem conditions and calculates the total population. This method extracts the locations for sampling from the problem
conditions and calculates the total population.
:param trainer: The trainer object managing the training process. :param trainer: The trainer object managing the training process.
:type trainer: pytorch_lightning.Trainer :type trainer: pytorch_lightning.Trainer
@@ -151,7 +167,8 @@ class R3Refinement(Callback):
""" """
Callback function called at the end of each training epoch. Callback function called at the end of each training epoch.
This method triggers the R3 routine for refinement if the current epoch is a multiple of `_sample_every`. This method triggers the R3 routine for refinement if the current
epoch is a multiple of `_sample_every`.
:param trainer: The trainer object managing the training process. :param trainer: The trainer object managing the training process.
:type trainer: pytorch_lightning.Trainer :type trainer: pytorch_lightning.Trainer

View File

@@ -73,3 +73,14 @@ def test_r3refinment_routine():
callbacks=[R3Refinement(sample_every=1)], callbacks=[R3Refinement(sample_every=1)],
max_epochs=5) max_epochs=5)
trainer.train() trainer.train()
def test_r3refinment_routine_double_precision():
model = FeedForward(len(poisson_problem.input_variables),
len(poisson_problem.output_variables))
solver = PINN(problem=poisson_problem, model=model)
trainer = Trainer(solver=solver,
precision='64-true',
accelerator='cpu',
callbacks=[R3Refinement(sample_every=2)],
max_epochs=5)
trainer.train()