R3Refinment double precision training fix (#277)
* r3 ref double precision fix * fix label missing --------- Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local> Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.Home>
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
"""PINA Callbacks Implementations"""
|
||||
|
||||
# from lightning.pytorch.callbacks import Callback
|
||||
from pytorch_lightning.callbacks import Callback
|
||||
import torch
|
||||
from pytorch_lightning.callbacks import Callback
|
||||
from ..label_tensor import LabelTensor
|
||||
from ..utils import check_consistency
|
||||
|
||||
|
||||
@@ -12,19 +12,22 @@ class R3Refinement(Callback):
|
||||
"""
|
||||
PINA Implementation of an R3 Refinement Callback.
|
||||
|
||||
This callback implements the R3 (Retain-Resample-Release) routine for sampling new points based on adaptive search.
|
||||
The algorithm incrementally accumulates collocation points in regions of high PDE residuals, and releases those
|
||||
with low residuals. Points are sampled uniformly in all regions where sampling is needed.
|
||||
This callback implements the R3 (Retain-Resample-Release) routine for
|
||||
sampling new points based on adaptive search.
|
||||
The algorithm incrementally accumulates collocation points in regions
|
||||
of high PDE residuals, and releases those
|
||||
with low residuals. Points are sampled uniformly in all regions
|
||||
where sampling is needed.
|
||||
|
||||
.. seealso::
|
||||
|
||||
Original Reference: Daw, Arka, et al. *Mitigating Propagation Failures in Physics-informed Neural Networks
|
||||
Original Reference: Daw, Arka, et al. *Mitigating Propagation
|
||||
Failures in Physics-informed Neural Networks
|
||||
using Retain-Resample-Release (R3) Sampling. (2023)*.
|
||||
DOI: `10.48550/arXiv.2207.02338
|
||||
<https://doi.org/10.48550/arXiv.2207.02338>`_
|
||||
|
||||
:param int sample_every: Frequency for sampling.
|
||||
|
||||
:raises ValueError: If `sample_every` is not an integer.
|
||||
|
||||
Example:
|
||||
@@ -47,6 +50,17 @@ class R3Refinement(Callback):
|
||||
# extract the solver and device from trainer
|
||||
solver = trainer._model
|
||||
device = trainer._accelerator_connector._accelerator_flag
|
||||
precision = trainer.precision
|
||||
if precision == "64-true":
|
||||
precision = torch.float64
|
||||
elif precision == "32-true":
|
||||
precision = torch.float32
|
||||
else:
|
||||
raise RuntimeError("Currently R3Refinement is only implemented "
|
||||
"for precision '32-true' and '64-true', set "
|
||||
"Trainer precision to match one of the "
|
||||
"available precisions.")
|
||||
|
||||
|
||||
# compute residual
|
||||
res_loss = {}
|
||||
@@ -55,10 +69,10 @@ class R3Refinement(Callback):
|
||||
condition = solver.problem.conditions[location]
|
||||
pts = solver.problem.input_pts[location]
|
||||
# send points to correct device
|
||||
pts = pts.to(device)
|
||||
pts = pts.to(device=device, dtype=precision)
|
||||
pts = pts.requires_grad_(True)
|
||||
pts.retain_grad()
|
||||
# PINN loss: equation evaluated only on locations where sampling is needed
|
||||
# PINN loss: equation evaluated only for sampling locations
|
||||
target = condition.equation.residual(pts, solver.forward(pts))
|
||||
res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
|
||||
tot_loss.append(torch.abs(target))
|
||||
@@ -86,13 +100,14 @@ class R3Refinement(Callback):
|
||||
for location in self._sampling_locations:
|
||||
pts = trainer._model.problem.input_pts[location]
|
||||
labels = pts.labels
|
||||
pts = pts.cpu().detach()
|
||||
pts = pts.cpu().detach().as_subclass(torch.Tensor)
|
||||
residuals = res_loss[location].cpu()
|
||||
mask = (residuals > avg).flatten()
|
||||
if any(
|
||||
mask
|
||||
): # if there are residuals greater than averge we append them
|
||||
pts = pts[mask] # TODO masking remove labels
|
||||
# Fix the issue, masking remove labels
|
||||
pts = (pts[mask]).as_subclass(LabelTensor)
|
||||
pts.labels = labels
|
||||
old_pts[location] = pts
|
||||
tot_points += len(pts)
|
||||
@@ -122,7 +137,8 @@ class R3Refinement(Callback):
|
||||
"""
|
||||
Callback function called at the start of training.
|
||||
|
||||
This method extracts the locations for sampling from the problem conditions and calculates the total population.
|
||||
This method extracts the locations for sampling from the problem
|
||||
conditions and calculates the total population.
|
||||
|
||||
:param trainer: The trainer object managing the training process.
|
||||
:type trainer: pytorch_lightning.Trainer
|
||||
@@ -151,7 +167,8 @@ class R3Refinement(Callback):
|
||||
"""
|
||||
Callback function called at the end of each training epoch.
|
||||
|
||||
This method triggers the R3 routine for refinement if the current epoch is a multiple of `_sample_every`.
|
||||
This method triggers the R3 routine for refinement if the current
|
||||
epoch is a multiple of `_sample_every`.
|
||||
|
||||
:param trainer: The trainer object managing the training process.
|
||||
:type trainer: pytorch_lightning.Trainer
|
||||
|
||||
@@ -73,3 +73,14 @@ def test_r3refinment_routine():
|
||||
callbacks=[R3Refinement(sample_every=1)],
|
||||
max_epochs=5)
|
||||
trainer.train()
|
||||
|
||||
def test_r3refinment_routine_double_precision():
|
||||
model = FeedForward(len(poisson_problem.input_variables),
|
||||
len(poisson_problem.output_variables))
|
||||
solver = PINN(problem=poisson_problem, model=model)
|
||||
trainer = Trainer(solver=solver,
|
||||
precision='64-true',
|
||||
accelerator='cpu',
|
||||
callbacks=[R3Refinement(sample_every=2)],
|
||||
max_epochs=5)
|
||||
trainer.train()
|
||||
|
||||
Reference in New Issue
Block a user