🎨 Format Python code with psf/black

This commit is contained in:
ndem0
2024-02-09 11:25:00 +00:00
committed by Nicola Demo
parent 591aeeb02b
commit cbb43a5392
64 changed files with 1323 additions and 955 deletions

View File

@@ -1,4 +1,4 @@
__all__ = ['SwitchOptimizer', 'R3Refinement', 'MetricTracker']
__all__ = ["SwitchOptimizer", "R3Refinement", "MetricTracker"]
from .optimizer_callbacks import SwitchOptimizer
from .adaptive_refinment_callbacks import R3Refinement

View File

@@ -1,4 +1,4 @@
'''PINA Callbacks Implementations'''
"""PINA Callbacks Implementations"""
# from lightning.pytorch.callbacks import Callback
from pytorch_lightning.callbacks import Callback
@@ -8,18 +8,17 @@ from ..utils import check_consistency
class R3Refinement(Callback):
def __init__(self, sample_every):
"""
PINA Implementation of an R3 Refinement Callback.
This callback implements the R3 (Retain-Resample-Release) routine for sampling new points based on adaptive search.
The algorithm incrementally accumulates collocation points in regions of high PDE residuals, and releases those
The algorithm incrementally accumulates collocation points in regions of high PDE residuals, and releases those
with low residuals. Points are sampled uniformly in all regions where sampling is needed.
.. seealso::
Original Reference: Daw, Arka, et al. *Mitigating Propagation Failures in Physics-informed Neural Networks
Original Reference: Daw, Arka, et al. *Mitigating Propagation Failures in Physics-informed Neural Networks
using Retain-Resample-Release (R3) Sampling. (2023)*.
DOI: `10.48550/arXiv.2207.02338
<https://doi.org/10.48550/arXiv.2207.02338>`_
@@ -79,7 +78,7 @@ class R3Refinement(Callback):
# !!!!!! From now everything is performed on CPU !!!!!!
# average loss
avg = (tot_loss.mean()).to('cpu')
avg = (tot_loss.mean()).to("cpu")
# points to keep
old_pts = {}
@@ -90,25 +89,29 @@ class R3Refinement(Callback):
pts = pts.cpu().detach()
residuals = res_loss[location].cpu()
mask = (residuals > avg).flatten()
if any(mask): # if there are residuals greater than averge we append them
pts = pts[mask] # TODO masking remove labels
if any(
mask
): # if there are residuals greater than averge we append them
pts = pts[mask] # TODO masking remove labels
pts.labels = labels
old_pts[location] = pts
tot_points += len(pts)
# extract new points to sample uniformally for each location
n_points = (self._tot_pop_numb - tot_points) // len(
self._sampling_locations)
self._sampling_locations
)
remainder = (self._tot_pop_numb - tot_points) % len(
self._sampling_locations)
self._sampling_locations
)
n_uniform_points = [n_points] * len(self._sampling_locations)
n_uniform_points[-1] += remainder
# sample new points
for numb_pts, loc in zip(n_uniform_points, self._sampling_locations):
trainer._model.problem.discretise_domain(numb_pts,
'random',
locations=[loc])
trainer._model.problem.discretise_domain(
numb_pts, "random", locations=[loc]
)
# adding previous population points
trainer._model.problem.add_points(old_pts)
@@ -133,7 +136,7 @@ class R3Refinement(Callback):
locations = []
for condition_name in problem.conditions:
condition = problem.conditions[condition_name]
if hasattr(condition, 'location'):
if hasattr(condition, "location"):
locations.append(condition_name)
self._sampling_locations = locations

View File

@@ -1,4 +1,4 @@
'''PINA Callbacks Implementations'''
"""PINA Callbacks Implementations"""
from pytorch_lightning.callbacks import Callback
import torch
@@ -14,7 +14,7 @@ class SwitchOptimizer(Callback):
This callback allows for switching between different optimizers during training, enabling
the exploration of multiple optimization strategies without the need to stop training.
:param new_optimizers: The model optimizers to switch to. Can be a single
:param new_optimizers: The model optimizers to switch to. Can be a single
:class:`torch.optim.Optimizer` or a list of them for multiple model solvers.
:type new_optimizers: torch.optim.Optimizer | list
:param new_optimizers_kwargs: The keyword arguments for the new optimizers. Can be a single dictionary
@@ -23,7 +23,7 @@ class SwitchOptimizer(Callback):
:param epoch_switch: The epoch at which to switch to the new optimizer.
:type epoch_switch: int
:raises ValueError: If `epoch_switch` is less than 1 or if there is a mismatch in the number of
:raises ValueError: If `epoch_switch` is less than 1 or if there is a mismatch in the number of
optimizers and their corresponding keyword argument dictionaries.
Example:
@@ -39,7 +39,7 @@ class SwitchOptimizer(Callback):
check_consistency(epoch_switch, int)
if epoch_switch < 1:
raise ValueError('epoch_switch must be greater than one.')
raise ValueError("epoch_switch must be greater than one.")
if not isinstance(new_optimizers, list):
new_optimizers = [new_optimizers]
@@ -48,10 +48,12 @@ class SwitchOptimizer(Callback):
len_optimizer_kwargs = len(new_optimizers_kwargs)
if len_optimizer_kwargs != len_optimizer:
raise ValueError('You must define one dictionary of keyword'
' arguments for each optimizers.'
f' Got {len_optimizer} optimizers, and'
f' {len_optimizer_kwargs} dicitionaries')
raise ValueError(
"You must define one dictionary of keyword"
" arguments for each optimizers."
f" Got {len_optimizer} optimizers, and"
f" {len_optimizer_kwargs} dicitionaries"
)
# save new optimizers
self._new_optimizers = new_optimizers
@@ -72,9 +74,12 @@ class SwitchOptimizer(Callback):
if trainer.current_epoch == self._epoch_switch:
optims = []
for idx, (optim, optim_kwargs) in enumerate(
zip(self._new_optimizers, self._new_optimizers_kwargs)):
zip(self._new_optimizers, self._new_optimizers_kwargs)
):
optims.append(
optim(trainer._model.models[idx].parameters(),
**optim_kwargs))
optim(
trainer._model.models[idx].parameters(), **optim_kwargs
)
)
trainer.optimizers = optims

View File

@@ -1,4 +1,4 @@
'''PINA Callbacks Implementations'''
"""PINA Callbacks Implementations"""
from pytorch_lightning.callbacks import Callback
import torch
@@ -6,7 +6,7 @@ import copy
class MetricTracker(Callback):
def __init__(self):
"""
PINA Implementation of a Lightning Callback for Metric Tracking.
@@ -39,8 +39,9 @@ class MetricTracker(Callback):
:return: None
:rtype: None
"""
self._collection.append(copy.deepcopy(
trainer.logged_metrics)) # track them
self._collection.append(
copy.deepcopy(trainer.logged_metrics)
) # track them
@property
def metrics(self):