🎨 Format Python code with psf/black
@@ -1,4 +1,4 @@
-'''PINA Callbacks Implementations'''
+"""PINA Callbacks Implementations"""
 
 # from lightning.pytorch.callbacks import Callback
 from pytorch_lightning.callbacks import Callback
@@ -8,18 +8,17 @@ from ..utils import check_consistency
 
 class R3Refinement(Callback):
 
     def __init__(self, sample_every):
         """
         PINA Implementation of an R3 Refinement Callback.
 
         This callback implements the R3 (Retain-Resample-Release) routine for sampling new points based on adaptive search.
         The algorithm incrementally accumulates collocation points in regions of high PDE residuals, and releases those
         with low residuals. Points are sampled uniformly in all regions where sampling is needed.
 
         .. seealso::
 
             Original Reference: Daw, Arka, et al. *Mitigating Propagation Failures in Physics-informed Neural Networks
             using Retain-Resample-Release (R3) Sampling. (2023)*.
             DOI: `10.48550/arXiv.2207.02338
             <https://doi.org/10.48550/arXiv.2207.02338>`_
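To make the Retain-Resample-Release routine described in the docstring concrete, here is a minimal, self-contained sketch in plain torch on synthetic data. It only illustrates the idea; the tensors and names are invented for the example and it is not code from this commit.

import torch

# a fixed population of collocation points in a 2D domain (synthetic data)
pop_size = 100
pts = torch.rand(pop_size, 2)
# PDE residual magnitude at each point (here just random numbers)
residuals = torch.rand(pop_size)

# Retain: keep points whose residual is above the population average
avg = residuals.mean()
retain_mask = residuals > avg
retained = pts[retain_mask]

# Release + Resample: low-residual points are dropped and the freed budget is
# refilled with uniformly sampled points, so the population size stays fixed
n_new = pop_size - retained.shape[0]
new_pts = torch.rand(n_new, 2)

pts = torch.cat([retained, new_pts])
assert pts.shape[0] == pop_size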
@@ -79,7 +78,7 @@ class R3Refinement(Callback):
 
         # !!!!!! From now everything is performed on CPU !!!!!!
 
         # average loss
-        avg = (tot_loss.mean()).to('cpu')
+        avg = (tot_loss.mean()).to("cpu")
 
         # points to keep
         old_pts = {}
@@ -90,25 +89,29 @@ class R3Refinement(Callback):
             pts = pts.cpu().detach()
             residuals = res_loss[location].cpu()
             mask = (residuals > avg).flatten()
-            if any(mask): # if there are residuals greater than averge we append them
-                pts = pts[mask] # TODO masking remove labels
+            if any(
+                mask
+            ): # if there are residuals greater than averge we append them
+                pts = pts[mask] # TODO masking remove labels
                 pts.labels = labels
                 old_pts[location] = pts
                 tot_points += len(pts)
 
         # extract new points to sample uniformally for each location
         n_points = (self._tot_pop_numb - tot_points) // len(
-            self._sampling_locations)
+            self._sampling_locations
+        )
         remainder = (self._tot_pop_numb - tot_points) % len(
-            self._sampling_locations)
+            self._sampling_locations
+        )
         n_uniform_points = [n_points] * len(self._sampling_locations)
         n_uniform_points[-1] += remainder
 
         # sample new points
         for numb_pts, loc in zip(n_uniform_points, self._sampling_locations):
-            trainer._model.problem.discretise_domain(numb_pts,
-                                                     'random',
-                                                     locations=[loc])
+            trainer._model.problem.discretise_domain(
+                numb_pts, "random", locations=[loc]
+            )
         # adding previous population points
         trainer._model.problem.add_points(old_pts)
 
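As a sanity check on the budget arithmetic in the hunk above (the // and % split), here is a small example with made-up numbers and invented location names; it only computes sample counts, no PINA objects are involved.

# hypothetical values, chosen only to illustrate the split
tot_pop_numb = 1000                      # fixed total population size
tot_points = 437                         # points retained because residual > average
locations = ["gamma1", "gamma2", "D"]    # invented location names

free = tot_pop_numb - tot_points         # 563 points must be resampled
n_points = free // len(locations)        # 187 per location
remainder = free % len(locations)        # 2 left over
n_uniform_points = [n_points] * len(locations)
n_uniform_points[-1] += remainder        # last location absorbs the remainder

assert sum(n_uniform_points) + tot_points == tot_pop_numb  # population preserved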
@@ -133,7 +136,7 @@ class R3Refinement(Callback):
         locations = []
         for condition_name in problem.conditions:
             condition = problem.conditions[condition_name]
-            if hasattr(condition, 'location'):
+            if hasattr(condition, "location"):
                 locations.append(condition_name)
         self._sampling_locations = locations
 
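The last hunk selects which conditions take part in the refinement: any condition exposing a `location` attribute is recorded as a sampling location. A runnable toy illustration of that filter follows; the condition objects and names are invented stand-ins, not PINA classes.

from types import SimpleNamespace

# stand-ins for problem.conditions; only objects with a `location` attribute
# are picked up as sampling locations
conditions = {
    "gamma_boundary": SimpleNamespace(location="boundary"),
    "interior": SimpleNamespace(location="domain"),
    "data_fit": SimpleNamespace(input_points=None),  # no location -> skipped
}

locations = []
for condition_name in conditions:
    condition = conditions[condition_name]
    if hasattr(condition, "location"):
        locations.append(condition_name)

print(locations)  # ['gamma_boundary', 'interior']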