R3Refinment double precision training fix (#277)
* r3 ref double precision fix * fix label missing --------- Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local> Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.Home>
This commit is contained in:
@@ -1,8 +1,8 @@
|
|||||||
"""PINA Callbacks Implementations"""
|
"""PINA Callbacks Implementations"""
|
||||||
|
|
||||||
# from lightning.pytorch.callbacks import Callback
|
|
||||||
from pytorch_lightning.callbacks import Callback
|
|
||||||
import torch
|
import torch
|
||||||
|
from pytorch_lightning.callbacks import Callback
|
||||||
|
from ..label_tensor import LabelTensor
|
||||||
from ..utils import check_consistency
|
from ..utils import check_consistency
|
||||||
|
|
||||||
|
|
||||||
@@ -12,19 +12,22 @@ class R3Refinement(Callback):
|
|||||||
"""
|
"""
|
||||||
PINA Implementation of an R3 Refinement Callback.
|
PINA Implementation of an R3 Refinement Callback.
|
||||||
|
|
||||||
This callback implements the R3 (Retain-Resample-Release) routine for sampling new points based on adaptive search.
|
This callback implements the R3 (Retain-Resample-Release) routine for
|
||||||
The algorithm incrementally accumulates collocation points in regions of high PDE residuals, and releases those
|
sampling new points based on adaptive search.
|
||||||
with low residuals. Points are sampled uniformly in all regions where sampling is needed.
|
The algorithm incrementally accumulates collocation points in regions
|
||||||
|
of high PDE residuals, and releases those
|
||||||
|
with low residuals. Points are sampled uniformly in all regions
|
||||||
|
where sampling is needed.
|
||||||
|
|
||||||
.. seealso::
|
.. seealso::
|
||||||
|
|
||||||
Original Reference: Daw, Arka, et al. *Mitigating Propagation Failures in Physics-informed Neural Networks
|
Original Reference: Daw, Arka, et al. *Mitigating Propagation
|
||||||
|
Failures in Physics-informed Neural Networks
|
||||||
using Retain-Resample-Release (R3) Sampling. (2023)*.
|
using Retain-Resample-Release (R3) Sampling. (2023)*.
|
||||||
DOI: `10.48550/arXiv.2207.02338
|
DOI: `10.48550/arXiv.2207.02338
|
||||||
<https://doi.org/10.48550/arXiv.2207.02338>`_
|
<https://doi.org/10.48550/arXiv.2207.02338>`_
|
||||||
|
|
||||||
:param int sample_every: Frequency for sampling.
|
:param int sample_every: Frequency for sampling.
|
||||||
|
|
||||||
:raises ValueError: If `sample_every` is not an integer.
|
:raises ValueError: If `sample_every` is not an integer.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@@ -47,6 +50,17 @@ class R3Refinement(Callback):
|
|||||||
# extract the solver and device from trainer
|
# extract the solver and device from trainer
|
||||||
solver = trainer._model
|
solver = trainer._model
|
||||||
device = trainer._accelerator_connector._accelerator_flag
|
device = trainer._accelerator_connector._accelerator_flag
|
||||||
|
precision = trainer.precision
|
||||||
|
if precision == "64-true":
|
||||||
|
precision = torch.float64
|
||||||
|
elif precision == "32-true":
|
||||||
|
precision = torch.float32
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Currently R3Refinement is only implemented "
|
||||||
|
"for precision '32-true' and '64-true', set "
|
||||||
|
"Trainer precision to match one of the "
|
||||||
|
"available precisions.")
|
||||||
|
|
||||||
|
|
||||||
# compute residual
|
# compute residual
|
||||||
res_loss = {}
|
res_loss = {}
|
||||||
@@ -55,10 +69,10 @@ class R3Refinement(Callback):
|
|||||||
condition = solver.problem.conditions[location]
|
condition = solver.problem.conditions[location]
|
||||||
pts = solver.problem.input_pts[location]
|
pts = solver.problem.input_pts[location]
|
||||||
# send points to correct device
|
# send points to correct device
|
||||||
pts = pts.to(device)
|
pts = pts.to(device=device, dtype=precision)
|
||||||
pts = pts.requires_grad_(True)
|
pts = pts.requires_grad_(True)
|
||||||
pts.retain_grad()
|
pts.retain_grad()
|
||||||
# PINN loss: equation evaluated only on locations where sampling is needed
|
# PINN loss: equation evaluated only for sampling locations
|
||||||
target = condition.equation.residual(pts, solver.forward(pts))
|
target = condition.equation.residual(pts, solver.forward(pts))
|
||||||
res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
|
res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
|
||||||
tot_loss.append(torch.abs(target))
|
tot_loss.append(torch.abs(target))
|
||||||
@@ -86,13 +100,14 @@ class R3Refinement(Callback):
|
|||||||
for location in self._sampling_locations:
|
for location in self._sampling_locations:
|
||||||
pts = trainer._model.problem.input_pts[location]
|
pts = trainer._model.problem.input_pts[location]
|
||||||
labels = pts.labels
|
labels = pts.labels
|
||||||
pts = pts.cpu().detach()
|
pts = pts.cpu().detach().as_subclass(torch.Tensor)
|
||||||
residuals = res_loss[location].cpu()
|
residuals = res_loss[location].cpu()
|
||||||
mask = (residuals > avg).flatten()
|
mask = (residuals > avg).flatten()
|
||||||
if any(
|
if any(
|
||||||
mask
|
mask
|
||||||
): # if there are residuals greater than averge we append them
|
): # if there are residuals greater than averge we append them
|
||||||
pts = pts[mask] # TODO masking remove labels
|
# Fix the issue, masking remove labels
|
||||||
|
pts = (pts[mask]).as_subclass(LabelTensor)
|
||||||
pts.labels = labels
|
pts.labels = labels
|
||||||
old_pts[location] = pts
|
old_pts[location] = pts
|
||||||
tot_points += len(pts)
|
tot_points += len(pts)
|
||||||
@@ -122,7 +137,8 @@ class R3Refinement(Callback):
|
|||||||
"""
|
"""
|
||||||
Callback function called at the start of training.
|
Callback function called at the start of training.
|
||||||
|
|
||||||
This method extracts the locations for sampling from the problem conditions and calculates the total population.
|
This method extracts the locations for sampling from the problem
|
||||||
|
conditions and calculates the total population.
|
||||||
|
|
||||||
:param trainer: The trainer object managing the training process.
|
:param trainer: The trainer object managing the training process.
|
||||||
:type trainer: pytorch_lightning.Trainer
|
:type trainer: pytorch_lightning.Trainer
|
||||||
@@ -151,7 +167,8 @@ class R3Refinement(Callback):
|
|||||||
"""
|
"""
|
||||||
Callback function called at the end of each training epoch.
|
Callback function called at the end of each training epoch.
|
||||||
|
|
||||||
This method triggers the R3 routine for refinement if the current epoch is a multiple of `_sample_every`.
|
This method triggers the R3 routine for refinement if the current
|
||||||
|
epoch is a multiple of `_sample_every`.
|
||||||
|
|
||||||
:param trainer: The trainer object managing the training process.
|
:param trainer: The trainer object managing the training process.
|
||||||
:type trainer: pytorch_lightning.Trainer
|
:type trainer: pytorch_lightning.Trainer
|
||||||
|
|||||||
@@ -73,3 +73,14 @@ def test_r3refinment_routine():
|
|||||||
callbacks=[R3Refinement(sample_every=1)],
|
callbacks=[R3Refinement(sample_every=1)],
|
||||||
max_epochs=5)
|
max_epochs=5)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
|
def test_r3refinment_routine_double_precision():
|
||||||
|
model = FeedForward(len(poisson_problem.input_variables),
|
||||||
|
len(poisson_problem.output_variables))
|
||||||
|
solver = PINN(problem=poisson_problem, model=model)
|
||||||
|
trainer = Trainer(solver=solver,
|
||||||
|
precision='64-true',
|
||||||
|
accelerator='cpu',
|
||||||
|
callbacks=[R3Refinement(sample_every=2)],
|
||||||
|
max_epochs=5)
|
||||||
|
trainer.train()
|
||||||
|
|||||||
Reference in New Issue
Block a user