* solvers -> solver
* adaptive_functions -> adaptive_function
* callbacks -> callback
* operators -> operator
* pinns -> physics_informed_solver
* layers -> block
Dario Coscia
2025-02-19 11:35:43 +01:00
committed by Nicola Demo
parent 810d215ca0
commit df673cad4e
90 changed files with 155 additions and 151 deletions

10
pina/callback/__init__.py Normal file

@@ -0,0 +1,10 @@
__all__ = [
"SwitchOptimizer",
"R3Refinement",
"MetricTracker",
"PINAProgressBar",
]
from .optimizer_callback import SwitchOptimizer
from .adaptive_refinment_callback import R3Refinement
from .processing_callback import MetricTracker, PINAProgressBar
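For reference, the re-exported callbacks can now be imported directly from the renamed, singular package. A minimal sketch, assuming only the exports declared above:

from pina.callback import (
    SwitchOptimizer,
    R3Refinement,
    MetricTracker,
    PINAProgressBar,
)
# These objects are standard Lightning callbacks and can be passed to the
# trainer, e.g. callbacks=[MetricTracker(), PINAProgressBar()] (assumed call).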

174
pina/callback/adaptive_refinment_callback.py Normal file

@@ -0,0 +1,174 @@
"""PINA Callbacks Implementations"""
import torch
from lightning.pytorch.callbacks import Callback
from ..label_tensor import LabelTensor
from ..utils import check_consistency
class R3Refinement(Callback):
def __init__(self, sample_every):
"""
PINA Implementation of an R3 Refinement Callback.
This callback implements the R3 (Retain-Resample-Release) routine for
adaptive resampling of collocation points. The algorithm incrementally
accumulates collocation points in regions of high PDE residual, releases
those with low residual, and samples the released points uniformly in
all regions where sampling is needed.
.. seealso::
Original Reference: Daw, Arka, et al. *Mitigating Propagation
Failures in Physics-informed Neural Networks
using Retain-Resample-Release (R3) Sampling. (2023)*.
DOI: `10.48550/arXiv.2207.02338
<https://doi.org/10.48550/arXiv.2207.02338>`_
:param int sample_every: Number of training epochs between two consecutive refinement steps.
:raises ValueError: If `sample_every` is not an integer.
Example:
>>> r3_callback = R3Refinement(sample_every=5)
"""
super().__init__()
# sample every
check_consistency(sample_every, int)
self._sample_every = sample_every
self._const_pts = None
def _compute_residual(self, trainer):
"""
Compute the PDE residuals for the solver on the current collocation points.
:return: The stacked pointwise residuals and a per-location dictionary of pointwise residuals.
:rtype: tuple
"""
# extract the solver and device from trainer
solver = trainer.solver
device = trainer._accelerator_connector._accelerator_flag
precision = trainer.precision
if precision == "64-true":
precision = torch.float64
elif precision == "32-true":
precision = torch.float32
else:
raise RuntimeError(
"Currently R3Refinement is only implemented "
"for precision '32-true' and '64-true', set "
"Trainer precision to match one of the "
"available precisions."
)
# compute residual
res_loss = {}
tot_loss = []
for location in self._sampling_locations: #TODO fix for new collector
condition = solver.problem.conditions[location]
pts = solver.problem.input_pts[location]
# send points to correct device
pts = pts.to(device=device, dtype=precision)
pts = pts.requires_grad_(True)
pts.retain_grad()
# PINN loss: equation evaluated only for sampling locations
target = condition.equation.residual(pts, solver.forward(pts))
res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
tot_loss.append(torch.abs(target))
return torch.vstack(tot_loss), res_loss
def _r3_routine(self, trainer):
"""
R3 refinement main routine.
:param Trainer trainer: PINA Trainer.
"""
# compute residuals (on whatever device the trainer uses)
tot_loss, res_loss = self._compute_residual(trainer)
tot_loss = tot_loss.as_subclass(torch.Tensor)
# !!!!!! From now everything is performed on CPU !!!!!!
# average loss
avg = (tot_loss.mean()).to("cpu")
old_pts = {} # points to be retained
for location in self._sampling_locations:
pts = trainer._model.problem.input_pts[location]
labels = pts.labels
pts = pts.cpu().detach().as_subclass(torch.Tensor)
residuals = res_loss[location].cpu()
mask = (residuals > avg).flatten()
if any(mask): # append residuals greater than average
pts = (pts[mask]).as_subclass(LabelTensor)
pts.labels = labels
old_pts[location] = pts
numb_pts = self._const_pts[location] - len(old_pts[location])
# sample new points
trainer._model.problem.discretise_domain(
numb_pts, "random", locations=[location]
)
else: # if no res greater than average, samples all uniformly
numb_pts = self._const_pts[location]
# sample new points
trainer._model.problem.discretise_domain(
numb_pts, "random", locations=[location]
)
# adding previous population points
trainer._model.problem.add_points(old_pts)
# update dataloader
trainer._create_or_update_loader()
def on_train_start(self, trainer, _):
"""
Callback function called at the start of training.
This method extracts the locations for sampling from the problem
conditions and calculates the total population.
:param trainer: The trainer object managing the training process.
:type trainer: pytorch_lightning.Trainer
:param _: Placeholder argument (not used).
:return: None
:rtype: None
"""
# extract locations for sampling
problem = trainer.solver.problem
locations = []
for condition_name in problem.conditions:
condition = problem.conditions[condition_name]
if hasattr(condition, "location"):
locations.append(condition_name)
self._sampling_locations = locations
# extract total population
const_pts = {} # for each location, store the # of pts to keep constant
for location in self._sampling_locations:
pts = trainer._model.problem.input_pts[location]
const_pts[location] = len(pts)
self._const_pts = const_pts
def on_train_epoch_end(self, trainer, __):
"""
Callback function called at the end of each training epoch.
This method triggers the R3 routine for refinement if the current
epoch is a multiple of `_sample_every`.
:param trainer: The trainer object managing the training process.
:type trainer: pytorch_lightning.Trainer
:param __: Placeholder argument (not used).
:return: None
:rtype: None
"""
if trainer.current_epoch % self._sample_every == 0:
self._r3_routine(trainer)
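To make the retain/release decision of _r3_routine concrete, here is a small standalone sketch with toy tensors; it is not the PINA API, and all names are illustrative:

import torch

torch.manual_seed(0)
pts = torch.rand(10, 2)        # toy collocation points in 2D
residuals = torch.rand(10)     # stand-in for pointwise PDE residuals

# Retain points whose residual exceeds the mean residual ...
mask = residuals > residuals.mean()
retained = pts[mask]

# ... and resample the released ones uniformly, so the population size stays
# constant (the role played by self._const_pts in the callback above).
n_new = len(pts) - len(retained)
resampled = torch.rand(n_new, 2)

new_population = torch.cat([retained, resampled])
assert len(new_population) == len(pts)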

65
pina/callback/optimizer_callback.py Normal file

@@ -0,0 +1,65 @@
"""PINA Callbacks Implementations"""
from lightning.pytorch.callbacks import Callback
import torch
from ..utils import check_consistency
from pina.optim import TorchOptimizer
class SwitchOptimizer(Callback):
def __init__(self, new_optimizers, epoch_switch):
"""
PINA Implementation of a Lightning Callback to switch optimizer during
training.
This callback allows for switching between different optimizers during
training, enabling the exploration of multiple optimization strategies
without the need to stop training.
:param new_optimizers: The optimizers to switch to. Can be a single
:class:`pina.optim.TorchOptimizer` or a list of them for multi-model
solvers.
:type new_optimizers: pina.optim.TorchOptimizer | list
:param epoch_switch: The epoch at which to switch to the new optimizer.
:type epoch_switch: int
Example:
>>> switch_callback = SwitchOptimizer(new_optimizers=optimizer,
...                                   epoch_switch=10)
"""
super().__init__()
if epoch_switch < 1:
raise ValueError("epoch_switch must be greater than one.")
if not isinstance(new_optimizers, list):
new_optimizers = [new_optimizers]
# check type consistency
for optimizer in new_optimizers:
check_consistency(optimizer, TorchOptimizer)
check_consistency(epoch_switch, int)
# save new optimizers
self._new_optimizers = new_optimizers
self._epoch_switch = epoch_switch
def on_train_epoch_start(self, trainer, __):
"""
Switch the optimizer at the start of the epoch specified by ``epoch_switch``.
:param trainer: The trainer object managing the training process.
:type trainer: pytorch_lightning.Trainer
:param __: Placeholder argument (not used).
:return: None
:rtype: None
"""
if trainer.current_epoch == self._epoch_switch:
optims = []
for idx, optim in enumerate(self._new_optimizers):
optim.hook(trainer.solver.models[idx].parameters())
optims.append(optim.instance)
trainer.optimizers = optims
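A possible two-stage optimization with the callback above, sketched under the assumption that TorchOptimizer wraps a torch.optim class together with its keyword arguments; the solver and Trainer setup are assumptions, only hinted at in comments:

import torch
from pina.optim import TorchOptimizer
from pina.callback import SwitchOptimizer

# Assumed wrapper signature: TorchOptimizer(<torch optimizer class>, **kwargs).
lbfgs = TorchOptimizer(torch.optim.LBFGS, lr=1.0)

# Train with the solver's original optimizer first, then switch to LBFGS
# once epoch 1000 starts.
switch = SwitchOptimizer(new_optimizers=lbfgs, epoch_switch=1000)

# trainer = Trainer(solver, callbacks=[switch], max_epochs=2000)  # assumed call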

152
pina/callback/processing_callback.py Normal file

@@ -0,0 +1,152 @@
"""PINA Callbacks Implementations"""
import torch
import copy
from lightning.pytorch.callbacks import Callback, TQDMProgressBar
from lightning.pytorch.callbacks.progress.progress_bar import (
get_standard_metrics,
)
from pina.utils import check_consistency
class MetricTracker(Callback):
def __init__(self, metrics_to_track=None):
"""
Lightning Callback for Metric Tracking.
Tracks specific metrics during the training process.
:ivar _collection: A list to store collected metrics after each epoch.
:param metrics_to_track: List of metrics to track. Defaults to train/val loss.
:type metrics_to_track: list, optional
"""
super().__init__()
self._collection = []
# Default to tracking 'train_loss' and 'val_loss' if not specified
self.metrics_to_track = metrics_to_track or ['train_loss', 'val_loss']
def on_train_epoch_end(self, trainer, pl_module):
"""
Collect and track metrics at the end of each training epoch.
:param trainer: The trainer object managing the training process.
:type trainer: pytorch_lightning.Trainer
:param pl_module: The model being trained (not used here).
"""
# Track metrics after the first epoch onwards
if trainer.current_epoch > 0:
# Append only the tracked metrics to avoid unnecessary data
tracked_metrics = {
k: v for k, v in trainer.logged_metrics.items()
if k in self.metrics_to_track
}
self._collection.append(copy.deepcopy(tracked_metrics))
@property
def metrics(self):
"""
Aggregate collected metrics over all epochs.
:return: A dictionary containing aggregated metric values.
:rtype: dict
"""
if not self._collection:
return {}
# Get intersection of keys across all collected dictionaries
common_keys = set(self._collection[0]).intersection(*self._collection[1:])
# Stack the metric values for common keys and return
return {
k: torch.stack([dic[k] for dic in self._collection])
for k in common_keys if k in self.metrics_to_track
}
class PINAProgressBar(TQDMProgressBar):
BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"
def __init__(self, metrics="val", **kwargs):
"""
PINA Implementation of a Lightning Callback for enriching the progress
bar.
This class provides functionality to display only relevant metrics
during the training process.
:param metrics: Logged metrics to display during training. It should
be a subset of the condition names defined in the problem (see
:obj:`pina.condition.Condition`), or the special keys ``train``/``val``.
:type metrics: str | list(str) | tuple(str)
:Keyword Arguments:
Additional keyword arguments configuring the progress bar; they
can be chosen from the `pytorch-lightning
TQDMProgressBar API <https://lightning.ai/docs/pytorch/stable/_modules/lightning/pytorch/callback/progress/tqdm_progress.html#TQDMProgressBar>`_
Example:
>>> pbar = PINAProgressBar(['mean'])
>>> trainer = Trainer(solver, callbacks=[pbar])
>>> # ... Perform training ...
"""
super().__init__(**kwargs)
# check consistency
if not isinstance(metrics, (list, tuple)):
metrics = [metrics]
check_consistency(metrics, str)
self._sorted_metrics = metrics
def get_metrics(self, trainer, pl_module):
r"""Combines progress bar metrics collected from the trainer with
standard metrics from get_standard_metrics.
Implement this to override the items displayed in the progress bar.
The progress bar metrics are sorted according to ``metrics``.
Here is an example of how to override the defaults:
.. code-block:: python
def get_metrics(self, trainer, model):
# don't show the version number
items = super().get_metrics(trainer, model)
items.pop("v_num", None)
return items
:return: Dictionary with the items to be displayed in the progress bar.
:rtype: dict
"""
standard_metrics = get_standard_metrics(trainer)
pbar_metrics = trainer.progress_bar_metrics
if pbar_metrics:
pbar_metrics = {
key: pbar_metrics[key] for key in self._sorted_metrics
}
return {**standard_metrics, **pbar_metrics}
def on_fit_start(self, trainer, pl_module):
"""
Check that the metrics defined in the initialization are available,
i.e. are correctly logged.
:param trainer: The trainer object managing the training process.
:type trainer: pytorch_lightning.Trainer
:param pl_module: Placeholder argument.
"""
# Check if all keys in sort_keys are present in the dictionary
for key in self._sorted_metrics:
if (
key not in trainer.solver.problem.conditions.keys()
and key != "train" and key != "val"
):
raise KeyError(f"Key '{key}' is not present in the dictionary")
# append the '_loss' suffix to each metric name
self._sorted_metrics = [
metric + "_loss" for metric in self._sorted_metrics
]
return super().on_fit_start(trainer, pl_module)
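Finally, a short usage sketch for the progress bar above; the condition name "data" and the Trainer call are assumptions:

from pina.callback import PINAProgressBar

# Display only the loss of the (assumed) condition named "data" and the total
# train loss; on_fit_start appends the "_loss" suffix to both keys.
pbar = PINAProgressBar(metrics=["data", "train"])

# trainer = Trainer(solver, callbacks=[pbar])  # assumed call
# trainer.train()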