🎨 Format Python code with psf/black

This commit is contained in:
ndem0
2024-02-09 11:25:00 +00:00
committed by Nicola Demo
parent 591aeeb02b
commit cbb43a5392
64 changed files with 1323 additions and 955 deletions

View File

@@ -1,4 +1,4 @@
'''PINA Callbacks Implementations'''
"""PINA Callbacks Implementations"""
# from lightning.pytorch.callbacks import Callback
from pytorch_lightning.callbacks import Callback
@@ -8,18 +8,17 @@ from ..utils import check_consistency
class R3Refinement(Callback):
def __init__(self, sample_every):
"""
PINA Implementation of an R3 Refinement Callback.
This callback implements the R3 (Retain-Resample-Release) routine for sampling new points based on adaptive search.
The algorithm incrementally accumulates collocation points in regions of high PDE residuals, and releases those
The algorithm incrementally accumulates collocation points in regions of high PDE residuals, and releases those
with low residuals. Points are sampled uniformly in all regions where sampling is needed.
.. seealso::
Original Reference: Daw, Arka, et al. *Mitigating Propagation Failures in Physics-informed Neural Networks
Original Reference: Daw, Arka, et al. *Mitigating Propagation Failures in Physics-informed Neural Networks
using Retain-Resample-Release (R3) Sampling. (2023)*.
DOI: `10.48550/arXiv.2207.02338
<https://doi.org/10.48550/arXiv.2207.02338>`_
@@ -79,7 +78,7 @@ class R3Refinement(Callback):
# !!!!!! From now everything is performed on CPU !!!!!!
# average loss
avg = (tot_loss.mean()).to('cpu')
avg = (tot_loss.mean()).to("cpu")
# points to keep
old_pts = {}
@@ -90,25 +89,29 @@ class R3Refinement(Callback):
pts = pts.cpu().detach()
residuals = res_loss[location].cpu()
mask = (residuals > avg).flatten()
if any(mask): # if there are residuals greater than averge we append them
pts = pts[mask] # TODO masking remove labels
if any(
mask
): # if there are residuals greater than averge we append them
pts = pts[mask] # TODO masking remove labels
pts.labels = labels
old_pts[location] = pts
tot_points += len(pts)
# extract new points to sample uniformally for each location
n_points = (self._tot_pop_numb - tot_points) // len(
self._sampling_locations)
self._sampling_locations
)
remainder = (self._tot_pop_numb - tot_points) % len(
self._sampling_locations)
self._sampling_locations
)
n_uniform_points = [n_points] * len(self._sampling_locations)
n_uniform_points[-1] += remainder
# sample new points
for numb_pts, loc in zip(n_uniform_points, self._sampling_locations):
trainer._model.problem.discretise_domain(numb_pts,
'random',
locations=[loc])
trainer._model.problem.discretise_domain(
numb_pts, "random", locations=[loc]
)
# adding previous population points
trainer._model.problem.add_points(old_pts)
@@ -133,7 +136,7 @@ class R3Refinement(Callback):
locations = []
for condition_name in problem.conditions:
condition = problem.conditions[condition_name]
if hasattr(condition, 'location'):
if hasattr(condition, "location"):
locations.append(condition_name)
self._sampling_locations = locations