Update callbacks and tests (#482)

---------

Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Dario Coscia
2025-03-13 16:19:38 +01:00
committed by FilippoOlivo
parent 18d178ab3a
commit 9dab6380f8
8 changed files with 264 additions and 229 deletions

View File

@@ -1,5 +1,6 @@
"""PINA Callbacks Implementations"""
import importlib.metadata
import torch
from lightning.pytorch.callbacks import Callback
from ..label_tensor import LabelTensor
@@ -7,17 +8,17 @@ from ..utils import check_consistency
class R3Refinement(Callback):
"""
PINA Implementation of an R3 Refinement Callback.
"""
def __init__(self, sample_every):
"""
PINA Implementation of an R3 Refinement Callback.
This callback implements the R3 (Retain-Resample-Release) routine for
sampling new points based on adaptive search.
The algorithm incrementally accumulates collocation points in regions
of high PDE residuals, and releases those
with low residuals. Points are sampled uniformly in all regions
where sampling is needed.
of high PDE residuals, and releases those with low residuals.
Points are sampled uniformly in all regions where sampling is needed.
.. seealso::
@@ -33,142 +34,148 @@ class R3Refinement(Callback):
Example:
>>> r3_callback = R3Refinement(sample_every=5)
"""
super().__init__()
raise NotImplementedError(
"R3Refinement callback is being refactored in the pina "
f"{importlib.metadata.metadata('pina-mathlab')['Version']} "
"version. Please use version 0.1 if R3Refinement is required."
)
# sample every
check_consistency(sample_every, int)
self._sample_every = sample_every
self._const_pts = None
# super().__init__()
def _compute_residual(self, trainer):
"""
Computes the residuals for a PINN object.
# # sample every
# check_consistency(sample_every, int)
# self._sample_every = sample_every
# self._const_pts = None
:return: the total loss, and pointwise loss.
:rtype: tuple
"""
# def _compute_residual(self, trainer):
# """
# Computes the residuals for a PINN object.
# extract the solver and device from trainer
solver = trainer.solver
device = trainer._accelerator_connector._accelerator_flag
precision = trainer.precision
if precision == "64-true":
precision = torch.float64
elif precision == "32-true":
precision = torch.float32
else:
raise RuntimeError(
"Currently R3Refinement is only implemented "
"for precision '32-true' and '64-true', set "
"Trainer precision to match one of the "
"available precisions."
)
# :return: the total loss, and pointwise loss.
# :rtype: tuple
# """
# compute residual
res_loss = {}
tot_loss = []
for location in self._sampling_locations: # TODO fix for new collector
condition = solver.problem.conditions[location]
pts = solver.problem.input_pts[location]
# send points to correct device
pts = pts.to(device=device, dtype=precision)
pts = pts.requires_grad_(True)
pts.retain_grad()
# PINN loss: equation evaluated only for sampling locations
target = condition.equation.residual(pts, solver.forward(pts))
res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
tot_loss.append(torch.abs(target))
# # extract the solver and device from trainer
# solver = trainer.solver
# device = trainer._accelerator_connector._accelerator_flag
# precision = trainer.precision
# if precision == "64-true":
# precision = torch.float64
# elif precision == "32-true":
# precision = torch.float32
# else:
# raise RuntimeError(
# "Currently R3Refinement is only implemented "
# "for precision '32-true' and '64-true', set "
# "Trainer precision to match one of the "
# "available precisions."
# )
print(tot_loss)
# # compute residual
# res_loss = {}
# tot_loss = []
# for location in self._sampling_locations:
# condition = solver.problem.conditions[location]
# pts = solver.problem.input_pts[location]
# # send points to correct device
# pts = pts.to(device=device, dtype=precision)
# pts = pts.requires_grad_(True)
# pts.retain_grad()
# # PINN loss: equation evaluated only for sampling locations
# target = condition.equation.residual(pts, solver.forward(pts))
# res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
# tot_loss.append(torch.abs(target))
return torch.vstack(tot_loss), res_loss
# print(tot_loss)
def _r3_routine(self, trainer):
"""
R3 refinement main routine.
# return torch.vstack(tot_loss), res_loss
:param Trainer trainer: PINA Trainer.
"""
# compute residual (all device possible)
tot_loss, res_loss = self._compute_residual(trainer)
tot_loss = tot_loss.as_subclass(torch.Tensor)
# def _r3_routine(self, trainer):
# """
# R3 refinement main routine.
# !!!!!! From now everything is performed on CPU !!!!!!
# :param Trainer trainer: PINA Trainer.
# """
# # compute residual (all device possible)
# tot_loss, res_loss = self._compute_residual(trainer)
# tot_loss = tot_loss.as_subclass(torch.Tensor)
# average loss
avg = (tot_loss.mean()).to("cpu")
old_pts = {} # points to be retained
for location in self._sampling_locations:
pts = trainer._model.problem.input_pts[location]
labels = pts.labels
pts = pts.cpu().detach().as_subclass(torch.Tensor)
residuals = res_loss[location].cpu()
mask = (residuals > avg).flatten()
if any(mask): # append residuals greater than average
pts = (pts[mask]).as_subclass(LabelTensor)
pts.labels = labels
old_pts[location] = pts
numb_pts = self._const_pts[location] - len(old_pts[location])
# sample new points
trainer._model.problem.discretise_domain(
numb_pts, "random", locations=[location]
)
# # !!!!!! From now everything is performed on CPU !!!!!!
else: # if no res greater than average, samples all uniformly
numb_pts = self._const_pts[location]
# sample new points
trainer._model.problem.discretise_domain(
numb_pts, "random", locations=[location]
)
# adding previous population points
trainer._model.problem.add_points(old_pts)
# # average loss
# avg = (tot_loss.mean()).to("cpu")
# old_pts = {} # points to be retained
# for location in self._sampling_locations:
# pts = trainer._model.problem.input_pts[location]
# labels = pts.labels
# pts = pts.cpu().detach().as_subclass(torch.Tensor)
# residuals = res_loss[location].cpu()
# mask = (residuals > avg).flatten()
# if any(mask): # append residuals greater than average
# pts = (pts[mask]).as_subclass(LabelTensor)
# pts.labels = labels
# old_pts[location] = pts
# numb_pts = self._const_pts[location] - len(old_pts[location])
# # sample new points
# trainer._model.problem.discretise_domain(
# numb_pts, "random", locations=[location]
# )
# update dataloader
trainer._create_or_update_loader()
# else: # if no res greater than average, samples all uniformly
# numb_pts = self._const_pts[location]
# # sample new points
# trainer._model.problem.discretise_domain(
# numb_pts, "random", locations=[location]
# )
# # adding previous population points
# trainer._model.problem.add_points(old_pts)
def on_train_start(self, trainer, _):
"""
Callback function called at the start of training.
# # update dataloader
# trainer._create_or_update_loader()
This method extracts the locations for sampling from the problem
conditions and calculates the total population.
# def on_train_start(self, trainer, _):
# """
# Callback function called at the start of training.
:param trainer: The trainer object managing the training process.
:type trainer: pytorch_lightning.Trainer
:param _: Placeholder argument (not used).
# This method extracts the locations for sampling from the problem
# conditions and calculates the total population.
:return: None
:rtype: None
"""
# extract locations for sampling
problem = trainer.solver.problem
locations = []
for condition_name in problem.conditions:
condition = problem.conditions[condition_name]
if hasattr(condition, "location"):
locations.append(condition_name)
self._sampling_locations = locations
# :param trainer: The trainer object managing the training process.
# :type trainer: pytorch_lightning.Trainer
# :param _: Placeholder argument (not used).
# extract total population
const_pts = {} # for each location, store the # of pts to keep constant
for location in self._sampling_locations:
pts = trainer._model.problem.input_pts[location]
const_pts[location] = len(pts)
self._const_pts = const_pts
# :return: None
# :rtype: None
# """
# # extract locations for sampling
# problem = trainer.solver.problem
# locations = []
# for condition_name in problem.conditions:
# condition = problem.conditions[condition_name]
# if hasattr(condition, "location"):
# locations.append(condition_name)
# self._sampling_locations = locations
def on_train_epoch_end(self, trainer, __):
"""
Callback function called at the end of each training epoch.
# # extract total population
# const_pts = {} # for each location, store the pts to keep constant
# for location in self._sampling_locations:
# pts = trainer._model.problem.input_pts[location]
# const_pts[location] = len(pts)
# self._const_pts = const_pts
This method triggers the R3 routine for refinement if the current
epoch is a multiple of `_sample_every`.
# def on_train_epoch_end(self, trainer, __):
# """
# Callback function called at the end of each training epoch.
:param trainer: The trainer object managing the training process.
:type trainer: pytorch_lightning.Trainer
:param __: Placeholder argument (not used).
# This method triggers the R3 routine for refinement if the current
# epoch is a multiple of `_sample_every`.
:return: None
:rtype: None
"""
if trainer.current_epoch % self._sample_every == 0:
self._r3_routine(trainer)
# :param trainer: The trainer object managing the training process.
# :type trainer: pytorch_lightning.Trainer
# :param __: Placeholder argument (not used).
# :return: None
# :rtype: None
# """
# if trainer.current_epoch % self._sample_every == 0:
# self._r3_routine(trainer)
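For readers skimming the commented-out implementation above, here is a minimal standalone sketch of the retain/release step the R3 routine performs; it assumes plain torch tensors rather than PINA's LabelTensor, and the names points, residuals, n_total and resample_fn are illustrative, not part of the library.

import torch

def r3_select(points, residuals, n_total, resample_fn):
    # Retain collocation points whose absolute residual exceeds the mean,
    # release the rest, and top up with freshly (uniformly) resampled points
    # so that the total number of points stays constant.
    mask = (residuals.abs() > residuals.abs().mean()).flatten()
    retained = points[mask]
    resampled = resample_fn(n_total - len(retained))
    return torch.vstack([retained, resampled])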

View File

@@ -37,12 +37,13 @@ class LinearWeightUpdate(Callback):
check_consistency(self.initial_value, (float, int), subclass=False)
check_consistency(self.target_value, (float, int), subclass=False)
def on_train_start(self, trainer, solver):
def on_train_start(self, trainer, pl_module):
"""
Initialize the weight of the condition to the specified `initial_value`.
:param Trainer trainer: a pina:class:`Trainer` instance.
:param SolverInterface solver: a pina:class:`SolverInterface` instance.
:param Trainer trainer: A :class:`~pina.trainer.Trainer` instance.
:param SolverInterface pl_module: A
:class:`~pina.solver.solver.SolverInterface` instance.
"""
# Check that the target epoch is valid
if not 0 < self.target_epoch <= trainer.max_epochs:
@@ -52,7 +53,7 @@ class LinearWeightUpdate(Callback):
)
# Check that the condition is a problem condition
if self.condition_name not in solver.problem.conditions:
if self.condition_name not in pl_module.problem.conditions:
raise ValueError(
f"`{self.condition_name}` must be a problem condition."
)
@@ -66,20 +67,21 @@ class LinearWeightUpdate(Callback):
)
# Check that the weighting schema is ScalarWeighting
if not isinstance(solver.weighting, ScalarWeighting):
if not isinstance(pl_module.weighting, ScalarWeighting):
raise ValueError("The weighting schema must be ScalarWeighting.")
# Initialize the weight of the condition
solver.weighting.weights[self.condition_name] = self.initial_value
pl_module.weighting.weights[self.condition_name] = self.initial_value
def on_train_epoch_start(self, trainer, solver):
def on_train_epoch_start(self, trainer, pl_module):
"""
Adjust at each epoch the weight of the condition.
:param Trainer trainer: a pina:class:`Trainer` instance.
:param SolverInterface solver: a pina:class:`SolverInterface` instance.
:param Trainer trainer: A :class:`~pina.trainer.Trainer` instance.
:param SolverInterface pl_module: A
:class:`~pina.solver.solver.SolverInterface` instance.
"""
if 0 < trainer.current_epoch <= self.target_epoch:
solver.weighting.weights[self.condition_name] += (
pl_module.weighting.weights[self.condition_name] += (
self.target_value - self.initial_value
) / (self.target_epoch - 1)
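The increment applied above is a linear ramp of the condition weight from initial_value to target_value; a small worked example with illustrative numbers (plain Python, no PINA objects):

initial_value, target_value, target_epoch = 1.0, 5.0, 5
step = (target_value - initial_value) / (target_epoch - 1)  # constant per-epoch increment
weights = [initial_value + k * step for k in range(target_epoch)]
print(weights)  # [1.0, 2.0, 3.0, 4.0, 5.0]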

View File

@@ -1,27 +1,27 @@
"""PINA Callbacks Implementations"""
from lightning.pytorch.callbacks import Callback
import torch
from ..optim import TorchOptimizer
from ..utils import check_consistency
from pina.optim import TorchOptimizer
class SwitchOptimizer(Callback):
"""
PINA Implementation of a Lightning Callback to switch optimizer during
training.
"""
def __init__(self, new_optimizers, epoch_switch):
"""
PINA Implementation of a Lightning Callback to switch optimizer during
training.
This callback allows for switching between different optimizers during
This callback allows switching between different optimizers during
training, enabling the exploration of multiple optimization strategies
without the need to stop training.
without interrupting the training process.
:param new_optimizers: The model optimizers to switch to. Can be a
single :class:`torch.optim.Optimizer` or a list of them for multiple
model solver.
single :class:`torch.optim.Optimizer` instance or a list of them
for solvers with multiple models.
:type new_optimizers: pina.optim.TorchOptimizer | list
:param epoch_switch: The epoch at which to switch to the new optimizer.
:param epoch_switch: The epoch at which the optimizer switch occurs.
:type epoch_switch: int
Example:
@@ -46,7 +46,7 @@ class SwitchOptimizer(Callback):
def on_train_epoch_start(self, trainer, __):
"""
Callback function to switch optimizer at the start of each training epoch.
Switch the optimizer at the start of the specified training epoch.
:param trainer: The trainer object managing the training process.
:type trainer: pytorch_lightning.Trainer
@@ -59,7 +59,7 @@ class SwitchOptimizer(Callback):
optims = []
for idx, optim in enumerate(self._new_optimizers):
optim.hook(trainer.solver.models[idx].parameters())
optims.append(optim.instance)
optim.hook(trainer.solver._pina_models[idx].parameters())
optims.append(optim)
trainer.optimizers = optims
trainer.solver._pina_optimizers = optims
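A possible usage sketch of the updated callback, mirroring the test further below; the import paths and the pre-built solver are assumptions, not part of this diff:

import torch
from pina import Trainer
from pina.optim import TorchOptimizer
from pina.callback import SwitchOptimizer  # assumed import path

# switch from the solver's initial optimizer to LBFGS at epoch 3
lbfgs = TorchOptimizer(torch.optim.LBFGS, lr=0.05)
trainer = Trainer(
    solver=solver,  # a previously built SolverInterface, e.g. a PINN
    callbacks=[SwitchOptimizer(new_optimizers=lbfgs, epoch_switch=3)],
    accelerator="cpu",
    max_epochs=5,
)
trainer.train()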

View File

@@ -1,7 +1,7 @@
"""PINA Callbacks Implementations"""
import torch
import copy
import torch
from lightning.pytorch.callbacks import Callback, TQDMProgressBar
from lightning.pytorch.callbacks.progress.progress_bar import (
@@ -11,22 +11,37 @@ from pina.utils import check_consistency
class MetricTracker(Callback):
"""
Lightning Callback for Metric Tracking.
"""
def __init__(self, metrics_to_track=None):
"""
Lightning Callback for Metric Tracking.
Tracks specified metrics during training.
Tracks the specified metrics during the training process.
:ivar _collection: A list to store collected metrics after each epoch.
:param metrics_to_track: List of metrics to track. Defaults to train/val loss.
:type metrics_to_track: list, optional
:param metrics_to_track: List of metrics to track.
Defaults to train loss.
:type metrics_to_track: list[str], optional
"""
super().__init__()
self._collection = []
# Default to tracking 'train_loss' and 'val_loss' if not specified
self.metrics_to_track = metrics_to_track or ["train_loss", "val_loss"]
# Default to tracking 'train_loss' if not specified
self.metrics_to_track = metrics_to_track
def setup(self, trainer, pl_module, stage):
"""
Called when fit, validate, test, predict, or tune begins.
:param Trainer trainer: A :class:`~pina.trainer.Trainer` instance.
:param SolverInterface pl_module: A
:class:`~pina.solver.solver.SolverInterface` instance.
:param str stage: Either 'fit', 'test' or 'predict'.
"""
if self.metrics_to_track is None and trainer.batch_size is None:
self.metrics_to_track = ["train_loss"]
elif self.metrics_to_track is None:
self.metrics_to_track = ["train_loss_epoch"]
return super().setup(trainer, pl_module, stage)
def on_train_epoch_end(self, trainer, pl_module):
"""
@@ -71,26 +86,28 @@ class MetricTracker(Callback):
class PINAProgressBar(TQDMProgressBar):
"""
PINA Implementation of a Lightning Callback for enriching the progress bar.
"""
BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"
BAR_FORMAT = (
"{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, "
"{rate_noinv_fmt}{postfix}]"
)
def __init__(self, metrics="val", **kwargs):
"""
PINA Implementation of a Lightning Callback for enriching the progress
bar.
This class enables the display of only relevant metrics during training.
This class provides functionality to display only relevant metrics
during the training process.
:param metrics: Logged metrics to display during the training. It should
be a subset of the conditions keys defined in
:param metrics: Logged metrics to be shown during the training.
Must be a subset of the conditions keys defined in
:obj:`pina.condition.Condition`.
:type metrics: str | list(str) | tuple(str)
:Keyword Arguments:
The additional keyword arguments specify the progress bar
and can be choosen from the `pytorch-lightning
TQDMProgressBar API <https://lightning.ai/docs/pytorch/stable/_modules/lightning/pytorch/callback/progress/tqdm_progress.html#TQDMProgressBar>`_
The additional keyword arguments specify the progress bar and can be
chosen from the `pytorch-lightning TQDMProgressBar API
<https://lightning.ai/docs/pytorch/stable/_modules/lightning/pytorch/callback/progress/tqdm_progress.html#TQDMProgressBar>`_
Example:
>>> pbar = PINAProgressBar(['mean'])
@@ -105,9 +122,9 @@ class PINAProgressBar(TQDMProgressBar):
self._sorted_metrics = metrics
def get_metrics(self, trainer, pl_module):
r"""Combines progress bar metrics collected from the trainer with
r"""Combine progress bar metrics collected from the trainer with
standard metrics from get_standard_metrics.
Implement this to override the items displayed in the progress bar.
Override this method to customize the items shown in the progress bar.
The progress bar metrics are sorted according to ``metrics``.
Here is an example of how to override the defaults:
@@ -122,20 +139,20 @@ class PINAProgressBar(TQDMProgressBar):
:return: Dictionary with the items to be displayed in the progress bar.
:rtype: tuple(dict)
"""
standard_metrics = get_standard_metrics(trainer)
pbar_metrics = trainer.progress_bar_metrics
if pbar_metrics:
pbar_metrics = {
key: pbar_metrics[key] for key in self._sorted_metrics
key: pbar_metrics[key]
for key in pbar_metrics
if key in self._sorted_metrics
}
return {**standard_metrics, **pbar_metrics}
def on_fit_start(self, trainer, pl_module):
def setup(self, trainer, pl_module, stage):
"""
Check that the metrics defined in the initialization are available,
i.e. are correctly logged.
Check that the metrics defined at initialization are available and correctly logged.
:param trainer: The trainer object managing the training process.
:type trainer: pytorch_lightning.Trainer
@@ -150,7 +167,11 @@ class PINAProgressBar(TQDMProgressBar):
):
raise KeyError(f"Key '{key}' is not present in the dictionary")
# add the loss pedix
if trainer.batch_size is not None:
pedix = "_loss_epoch"
else:
pedix = "_loss"
self._sorted_metrics = [
metric + "_loss" for metric in self._sorted_metrics
metric + pedix for metric in self._sorted_metrics
]
return super().on_fit_start(trainer, pl_module)
return super().setup(trainer, pl_module, stage)
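A possible usage sketch of the two callbacks together, assuming a pre-built solver as in the test files below; the module path for MetricTracker is assumed to match the one used for PINAProgressBar:

from pina import Trainer
from pina.callback.processing_callback import MetricTracker, PINAProgressBar

tracker = MetricTracker()  # with no argument, defaults to tracking the train loss
trainer = Trainer(
    solver=solver,  # a previously built SolverInterface, e.g. a PINN
    callbacks=[tracker, PINAProgressBar()],
    accelerator="cpu",
    max_epochs=5,
    log_every_n_steps=1,
)
trainer.train()
history = tracker.metrics  # dictionary of the tracked metric values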

View File

@@ -64,8 +64,7 @@ class Trainer(lightning.pytorch.Trainer):
:Keyword Arguments:
The additional keyword arguments specify the training setup
and can be chosen from the `pytorch-lightning
Trainer API <https://lightning.ai/docs/pytorch/stable/common/
trainer.html#trainer-class-api>`_
Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
"""
# check consistency for init types
self._check_input_consistency(
@@ -96,7 +95,6 @@ class Trainer(lightning.pytorch.Trainer):
# Setting default kwargs, overriding lightning defaults
kwargs.setdefault("enable_progress_bar", True)
kwargs.setdefault("logger", None)
super().__init__(**kwargs)
@@ -127,9 +125,6 @@ class Trainer(lightning.pytorch.Trainer):
# logging
self.logging_kwargs = {
"logger": bool(
kwargs["logger"] is not None or kwargs["logger"] is True
),
"sync_dist": bool(
len(self._accelerator_connector._parallel_devices) > 1
),
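As the docstring above notes, additional keyword arguments are forwarded to lightning.pytorch.Trainer; a brief hypothetical example (the extra kwarg names come from the Lightning Trainer API, the solver is assumed to exist):

from pina import Trainer

trainer = Trainer(
    solver=solver,            # a previously built SolverInterface
    max_epochs=100,
    accelerator="cpu",
    gradient_clip_val=1.0,    # forwarded to lightning.pytorch.Trainer
    deterministic=True,       # forwarded to lightning.pytorch.Trainer
)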

View File

@@ -23,17 +23,18 @@ def test_metric_tracker_constructor():
MetricTracker()
# def test_metric_tracker_routine(): #TODO revert
# # make the trainer
# trainer = Trainer(solver=solver,
# callback=[
# MetricTracker()
# ],
# accelerator='cpu',
# max_epochs=5)
# trainer.train()
# # get the tracked metrics
# metrics = trainer.callback[0].metrics
# # assert the logged metrics are correct
# logged_metrics = sorted(list(metrics.keys()))
# assert logged_metrics == ['train_loss_epoch', 'train_loss_step', 'val_loss']
def test_metric_tracker_routine():
# make the trainer
trainer = Trainer(
solver=solver,
callbacks=[MetricTracker()],
accelerator="cpu",
max_epochs=5,
log_every_n_steps=1,
)
trainer.train()
# get the tracked metrics
metrics = trainer.callbacks[0].metrics
# assert the logged metrics are correct
logged_metrics = sorted(list(metrics.keys()))
assert logged_metrics == ["train_loss"]

View File

@@ -21,19 +21,25 @@ model = FeedForward(
# make the solver
solver = PINN(problem=poisson_problem, model=model)
adam_optimizer = TorchOptimizer(torch.optim.Adam, lr=0.01)
lbfgs_optimizer = TorchOptimizer(torch.optim.LBFGS, lr=0.001)
adam = TorchOptimizer(torch.optim.Adam, lr=0.01)
lbfgs = TorchOptimizer(torch.optim.LBFGS, lr=0.001)
def test_switch_optimizer_constructor():
SwitchOptimizer(adam_optimizer, epoch_switch=10)
SwitchOptimizer(adam, epoch_switch=10)
# def test_switch_optimizer_routine(): #TODO revert
# # make the trainer
# switch_opt_callback = SwitchOptimizer(lbfgs_optimizer, epoch_switch=3)
# trainer = Trainer(solver=solver,
# callback=[switch_opt_callback],
# accelerator='cpu',
# max_epochs=5)
# trainer.train()
def test_switch_optimizer_routine():
# check initial optimizer
solver.configure_optimizers()
assert solver.optimizer.instance.__class__ == torch.optim.Adam
# make the trainer
switch_opt_callback = SwitchOptimizer(lbfgs, epoch_switch=3)
trainer = Trainer(
solver=solver,
callbacks=[switch_opt_callback],
accelerator="cpu",
max_epochs=5,
)
trainer.train()
assert solver.optimizer.instance.__class__ == torch.optim.LBFGS

View File

@@ -5,29 +5,32 @@ from pina.callback.processing_callback import PINAProgressBar
from pina.problem.zoo import Poisson2DSquareProblem as Poisson
# # make the problem
# poisson_problem = Poisson()
# boundaries = ['nil_g1', 'nil_g2', 'nil_g3', 'nil_g4']
# n = 10
# poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
# poisson_problem.discretise_domain(n, 'grid', locations='laplace_D')
# model = FeedForward(len(poisson_problem.input_variables),
# len(poisson_problem.output_variables))
# make the problem
poisson_problem = Poisson()
boundaries = ["g1", "g2", "g3", "g4"]
n = 10
condition_names = list(poisson_problem.conditions.keys())
poisson_problem.discretise_domain(n, "grid", domains=boundaries)
poisson_problem.discretise_domain(n, "grid", domains="D")
model = FeedForward(
len(poisson_problem.input_variables), len(poisson_problem.output_variables)
)
# # make the solver
# solver = PINN(problem=poisson_problem, model=model)
# make the solver
solver = PINN(problem=poisson_problem, model=model)
# def test_progress_bar_constructor():
# PINAProgressBar(['mean'])
def test_progress_bar_constructor():
PINAProgressBar()
# def test_progress_bar_routine():
# # make the trainer
# trainer = Trainer(solver=solver,
# callback=[
# PINAProgressBar(['mean', 'laplace_D'])
# ],
# accelerator='cpu',
# max_epochs=5)
# trainer.train()
# # TODO there should be a check that the correct metrics are displayed
def test_progress_bar_routine():
# make the trainer
trainer = Trainer(
solver=solver,
callbacks=[PINAProgressBar(["val", condition_names[0]])],
accelerator="cpu",
max_epochs=5,
)
trainer.train()
# TODO there should be a check that the correct metrics are displayed