Adaptive Refinement and Multiple Optimizer callbacks

* Implementing a callback to switch between optimizers during training (a usage sketch of both new callbacks is shown below)
* Implementing the R3Refinement callback for collocation points
* Modifying the Trainer: the dataloader is now created or updated by calling `_create_or_update_loader`
* Adding an `add_points` routine to AbstractProblem so that new points can be added without resampling from scratch
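
A minimal usage sketch of the two new callbacks. The solver object, the import paths, the optimizer choice, and the epoch numbers are illustrative assumptions, not part of this commit:

import torch
from pina import Trainer
from pina.callbacks import SwitchOptimizer, R3Refinement

# `solver` is assumed to be an already-built PINA solver (e.g. a PINN)
callbacks = [
    # switch to Adam with a smaller learning rate at epoch 100
    SwitchOptimizer(torch.optim.Adam, {'lr': 0.0001}, 100),
    # resample collocation points with R3 every 20 epochs
    R3Refinement(sample_every=20),
]

trainer = Trainer(solver, callbacks=callbacks, max_epochs=200)
trainer.train()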
Authored by Dario Coscia on 2023-09-14 18:37:02 +02:00, committed by Nicola Demo
parent 5a4c114d48
commit 4d1187898f
3 changed files with 229 additions and 4 deletions

pina/callbacks.py (new file, 190 additions)

@@ -0,0 +1,190 @@
'''PINA Callbacks Implementations'''

from lightning.pytorch.callbacks import Callback
import torch
from .utils import check_consistency


class SwitchOptimizer(Callback):
    """
    PINA implementation of a Lightning Callback to switch
    optimizer during training. The routine can be used to
    try multiple optimizers during training, without the
    need to stop and restart the run.
    """

    def __init__(self, new_optimizers, new_optimizers_kargs, epoch_switch):
        """
        SwitchOptimizer is a routine for switching optimizer during training.

        :param torch.optim.Optimizer | list new_optimizers: The model optimizers to
            switch to. It must be a single :class:`torch.optim.Optimizer`, or a list of
            :class:`torch.optim.Optimizer` for multiple-model solvers.
        :param dict | list new_optimizers_kargs: The keyword arguments of the optimizers
            to switch to. It must be a dict, or a list of dicts for multiple optimizers.
        :param int epoch_switch: Epoch at which the optimizer is switched.
        """
        super().__init__()

        # check type consistency
        check_consistency(new_optimizers, torch.optim.Optimizer, subclass=True)
        check_consistency(new_optimizers_kargs, dict)
        check_consistency(epoch_switch, int)

        if epoch_switch < 1:
            raise ValueError('epoch_switch must be at least one.')

        # wrap a single optimizer and its kwargs in lists for uniform handling
        if not isinstance(new_optimizers, list):
            optimizers = [new_optimizers]
            optimizers_kwargs = [new_optimizers_kargs]
        else:
            optimizers = new_optimizers
            optimizers_kwargs = new_optimizers_kargs

        len_optimizer = len(optimizers)
        len_optimizer_kwargs = len(optimizers_kwargs)

        if len_optimizer_kwargs != len_optimizer:
            raise ValueError('You must define one dictionary of keyword'
                             ' arguments for each optimizer.'
                             f' Got {len_optimizer} optimizers, and'
                             f' {len_optimizer_kwargs} dictionaries.')

        # save new optimizers
        self._new_optimizers = optimizers
        self._new_optimizers_kwargs = optimizers_kwargs
        self._epoch_switch = epoch_switch

    def on_train_epoch_start(self, trainer, __):
        if trainer.current_epoch == self._epoch_switch:
            # build one new optimizer per model and hand them over to the trainer
            optims = []
            for idx, (optim, optim_kwargs) in enumerate(
                    zip(self._new_optimizers, self._new_optimizers_kwargs)):
                optims.append(
                    optim(trainer._model.models[idx].parameters(), **optim_kwargs))

            trainer.optimizers = optims
class R3Refinement(Callback):
    """
    PINA implementation of an R3 Refinement Callback.

    .. seealso::

        **Original reference**: Daw, Arka, et al. "Mitigating Propagation Failures
        in Physics-informed Neural Networks using
        Retain-Resample-Release (R3) Sampling." (2023).
        DOI: `10.48550/arXiv.2207.02338
        <https://doi.org/10.48550/arXiv.2207.02338>`_
    """

    def __init__(self, sample_every):
        """
        R3 routine for sampling new points based on an
        adaptive search. The algorithm incrementally
        accumulates collocation points in regions of
        high PDE residual, and releases those with
        low residual. New points are sampled uniformly
        in all regions where sampling is needed.

        :param int sample_every: Frequency (in epochs) for resampling.
        """
        super().__init__()

        # sample every
        check_consistency(sample_every, int)
        self._sample_every = sample_every

    def _compute_residual(self, trainer):
        """
        Computes the residuals for a PINN object.

        :return: the total loss, and the pointwise loss per location.
        :rtype: tuple
        """
        # extract the solver and device from trainer
        solver = trainer._model
        device = trainer._accelerator_connector._accelerator_flag

        # compute residual
        res_loss = {}
        tot_loss = []
        for location in self._sampling_locations:
            condition = solver.problem.conditions[location]
            pts = solver.problem.input_pts[location]
            # send points to correct device
            pts = pts.to(device)
            pts = pts.requires_grad_(True)
            pts.retain_grad()
            # PINN loss: equation evaluated only on locations where sampling is needed
            target = condition.equation.residual(pts, solver.forward(pts))
            res_loss[location] = torch.abs(target)
            tot_loss.append(torch.abs(target))

        return torch.vstack(tot_loss), res_loss

    def _r3_routine(self, trainer):
        """
        R3 refinement main routine.

        :param Trainer trainer: PINA Trainer.
        """
        # compute residual (any device is possible here)
        tot_loss, res_loss = self._compute_residual(trainer)

        # !!!!!! From now on everything is performed on CPU !!!!!!

        # average loss
        avg = (tot_loss.mean()).to('cpu')

        # points to keep: those whose residual is above the average
        old_pts = {}
        tot_points = 0
        for location in self._sampling_locations:
            pts = trainer._model.problem.input_pts[location]
            labels = pts.labels
            pts = pts.cpu().detach()
            residuals = res_loss[location].cpu()
            mask = (residuals > avg).flatten()
            # TODO masking removes labels
            pts = pts[mask]
            pts.labels = labels
            old_pts[location] = pts
            tot_points += len(pts)

        # number of new points to sample uniformly for each location
        n_points = (self._tot_pop_numb - tot_points) // len(self._sampling_locations)
        remainder = (self._tot_pop_numb - tot_points) % len(self._sampling_locations)
        n_uniform_points = [n_points] * len(self._sampling_locations)
        n_uniform_points[-1] += remainder

        # sample new points
        for numb_pts, loc in zip(n_uniform_points, self._sampling_locations):
            trainer._model.problem.discretise_domain(numb_pts,
                                                     'random',
                                                     locations=[loc])
        # add the retained points from the previous population
        trainer._model.problem.add_points(old_pts)

        # update dataloader
        trainer._create_or_update_loader()

    def on_train_start(self, trainer, _):
        # extract the locations where sampling is needed
        problem = trainer._model.problem
        locations = []
        for condition_name in problem.conditions:
            condition = problem.conditions[condition_name]
            if hasattr(condition, 'location'):
                locations.append(condition_name)
        self._sampling_locations = locations

        # extract the total population size, kept constant across refinements
        total_population = 0
        for location in self._sampling_locations:
            pts = trainer._model.problem.input_pts[location]
            total_population += len(pts)
        self._tot_pop_numb = total_population

    def on_train_epoch_end(self, trainer, __):
        if trainer.current_epoch % self._sample_every == 0:
            self._r3_routine(trainer)

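Read together, `_compute_residual` and `_r3_routine` implement the retain/release step: points whose residual exceeds the mean are retained, the rest are released and their budget is refilled with uniform samples, so the total population size stays constant. A self-contained sketch of that masking logic on plain tensors (illustrative values, no PINA types):

import torch

# toy residuals for 10 collocation points in one location (illustrative values)
pts = torch.rand(10, 2)                  # collocation points (x, y)
residuals = torch.rand(10, 1)            # |PDE residual| at each point

avg = residuals.mean()
mask = (residuals > avg).flatten()

retained = pts[mask]                     # points kept (high residual)
n_released = len(pts) - len(retained)    # budget refilled by uniform sampling

new_uniform = torch.rand(n_released, 2)  # stand-in for discretise_domain(..., 'random')
new_population = torch.vstack([new_uniform, retained])
assert len(new_population) == len(pts)   # the total population size is preserved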

@@ -1,6 +1,7 @@
""" Module for AbstractProblem class """
from abc import ABCMeta, abstractmethod
from ..utils import merge_tensors, check_consistency
import torch
class AbstractProblem(metaclass=ABCMeta):
@@ -201,6 +202,36 @@ class AbstractProblem(metaclass=ABCMeta):
if sorted(self.input_pts[location].labels) == sorted(self.input_variables):
self._have_sampled_points[location] = True
def add_points(self, new_points):
"""
Adding points to the already sampled points
:param dict new_points: a dictionary with key the location to add the points
and values the torch.Tensor points.
"""
if sorted(new_points.keys()) != sorted(self.conditions):
TypeError(f'Wrong locations for new points. Location ',
f'should be in {self.conditions}.')
for location in new_points.keys():
# extract old and new points
old_pts = self.input_pts[location]
new_pts = new_points[location]
# if they don't have the same variables error
if sorted(old_pts.labels) != sorted(new_pts.labels):
TypeError(f'Not matching variables for old and new points '
f'in condition {location}.')
if old_pts.labels != new_pts.labels:
new_pts = torch.hstack([new_pts.extract([i]) for i in old_pts.labels])
new_pts.labels = old_pts.labels
# merging
merged_pts = torch.vstack([old_pts, new_points[location]])
merged_pts.labels = old_pts.labels
self.input_pts[location] = merged_pts
@property
def have_sampled_points(self):
"""

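As a sketch of how `add_points` is meant to be used: the `problem` instance, the condition name 'D', and the `LabelTensor` construction below are assumptions based on the surrounding code, not part of this diff.

import torch
from pina import LabelTensor

# assume `problem` is a 2D problem that has already been discretised, e.g.
#   problem.discretise_domain(100, 'random', locations=['D'])

# 20 extra collocation points for the (assumed) condition named 'D'
extra = LabelTensor(torch.rand(20, 2), labels=['x', 'y'])

# appended to the existing population; the old points are not resampled
problem.add_points({'D': extra})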

@@ -10,9 +10,6 @@ class Trainer(pl.Trainer):
    def __init__(self, solver, **kwargs):
        super().__init__(**kwargs)

        # get accelerator
        device = self._accelerator_connector._accelerator_flag

        # check inheritance consistency for solver
        check_consistency(solver, SolverInterface)
        self._model = solver
@@ -26,8 +23,15 @@ class Trainer(pl.Trainer):
                             'in the provided locations.')

        # TODO: make a better dataloader for train
        self._loader = DummyLoader(solver.problem.input_pts, device)
        self._create_or_update_loader()

    # this method is used here because, if resampling is needed
    # during training, there is no need to touch the trainer
    # dataloader: just call this method.
    def _create_or_update_loader(self):
        # get accelerator
        device = self._accelerator_connector._accelerator_flag
        self._loader = DummyLoader(self._model.problem.input_pts, device)

    def train(self, **kwargs):  # TODO add kwargs and lightning capabilities
        return super().fit(self._model, self._loader, **kwargs)
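
With the loader construction factored out into `_create_or_update_loader`, any callback that changes the problem's points mid-training can follow the same pattern as `R3Refinement`: modify the points, then refresh the loader. A hypothetical minimal example (the class name, the resampling schedule, and the condition list are illustrative assumptions):

from lightning.pytorch.callbacks import Callback


class UniformResampling(Callback):
    """Hypothetical callback: redraw the collocation points every 50 epochs."""

    def on_train_epoch_end(self, trainer, _):
        if trainer.current_epoch > 0 and trainer.current_epoch % 50 == 0:
            # resample the (illustrative) location 'D' from scratch
            trainer._model.problem.discretise_domain(100, 'random', locations=['D'])
            # rebuild the dataloader so training sees the new points
            trainer._create_or_update_loader()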