'''PINA Callbacks Implementations''' from lightning.pytorch.callbacks import Callback import torch from .utils import check_consistency class SwitchOptimizer(Callback): """ PINA implementation of a Lightining Callback to switch optimizer during training. The rouutine can be used to try multiple optimizers during the training, without the need to stop training. """ def __init__(self, new_optimizers, new_optimizers_kargs, epoch_switch): """ SwitchOptimizer is a routine for switching optimizer during training. :param torch.optim.Optimizer | list new_optimizers: The model optimizers to switch to. It must be a list of :class:`torch.optim.Optimizer` or list of :class:`torch.optim.Optimizer` for multiple model solvers. :param dict| list new_optimizers: The model optimizers keyword arguments to switch use. It must be a dict or list of dict for multiple optimizers. :param int epoch_switch: Epoch for switching optimizer. """ super().__init__() # check type consistency check_consistency(new_optimizers, torch.optim.Optimizer, subclass=True) check_consistency(new_optimizers_kargs, dict) check_consistency(epoch_switch, int) if epoch_switch < 1: raise ValueError('epoch_switch must be greater than one.') if not isinstance(new_optimizers, list): optimizers = [new_optimizers] optimizers_kwargs = [new_optimizers_kargs] len_optimizer = len(optimizers) len_optimizer_kwargs = len(optimizers_kwargs) if len_optimizer_kwargs != len_optimizer: raise ValueError('You must define one dictionary of keyword' ' arguments for each optimizers.' f'Got {len_optimizer} optimizers, and' f' {len_optimizer_kwargs} dicitionaries') # save new optimizers self._new_optimizers = optimizers self._new_optimizers_kwargs = optimizers_kwargs self._epoch_switch = epoch_switch def on_train_epoch_start(self, trainer, __): if trainer.current_epoch == self._epoch_switch: optims = [] for idx, (optim, optim_kwargs) in enumerate( zip(self._new_optimizers, self._new_optimizers_kwargs) ): optims.append(optim(trainer._model.models[idx].parameters(), **optim_kwargs)) trainer.optimizers = optims class R3Refinement(Callback): """ PINA implementation of a R3 Refinement Callback. .. seealso:: **Original reference**: Daw, Arka, et al. "Mitigating Propagation Failures in Physics-informed Neural Networks using Retain-Resample-Release (R3) Sampling." (2023). DOI: `10.48550/arXiv.2207.02338 < https://doi.org/10.48550/arXiv.2207.02338>`_ """ def __init__(self, sample_every): """ R3 routine for sampling new points based on adpative search. The algorithm incrementally accumulate collocation points in regions of high PDE residuals, and release the one which have low residual. Points are sampled uniformmaly in all region where sampling is needed. :param int sample_every: Frequency for sampling. """ super().__init__() # sample every check_consistency(sample_every, int) self._sample_every = sample_every def _compute_residual(self, trainer): """ Computes the residuals for a PINN object. :return: the total loss, and pointwise loss. :rtype: tuple """ # extract the solver and device from trainer solver = trainer._model device = trainer._accelerator_connector._accelerator_flag # compute residual res_loss = {} tot_loss = [] for location in self._sampling_locations: condition = solver.problem.conditions[location] pts = solver.problem.input_pts[location] # send points to correct device pts = pts.to(device) pts = pts.requires_grad_(True) pts.retain_grad() # PINN loss: equation evaluated only on locations where sampling is needed target = condition.equation.residual(pts, solver.forward(pts)) res_loss[location] = torch.abs(target) tot_loss.append(torch.abs(target)) return torch.vstack(tot_loss), res_loss def _r3_routine(self, trainer): """ R3 refinement main routine. :param Trainer trainer: PINA Trainer. """ # compute residual (all device possible) tot_loss, res_loss = self._compute_residual(trainer) # !!!!!! From now everything is performed on CPU !!!!!! # average loss avg = (tot_loss.mean()).to('cpu') # points to keep old_pts = {} tot_points = 0 for location in self._sampling_locations: pts = trainer._model.problem.input_pts[location] labels = pts.labels pts = pts.cpu().detach() residuals = res_loss[location].cpu() mask = (residuals > avg).flatten() # TODO masking remove labels pts = pts[mask] pts.labels = labels #### old_pts[location] = pts tot_points += len(pts) # extract new points to sample uniformally for each location n_points = (self._tot_pop_numb - tot_points ) // len(self._sampling_locations) remainder = (self._tot_pop_numb - tot_points ) % len(self._sampling_locations) n_uniform_points = [n_points] * len(self._sampling_locations) n_uniform_points[-1] += remainder # sample new points for numb_pts, loc in zip(n_uniform_points, self._sampling_locations): trainer._model.problem.discretise_domain(numb_pts, 'random', locations=[loc]) # adding previous population points trainer._model.problem.add_points(old_pts) # update dataloader trainer._create_or_update_loader() def on_train_start(self, trainer, _): # extract locations for sampling problem = trainer._model.problem locations = [] for condition_name in problem.conditions: condition = problem.conditions[condition_name] if hasattr(condition, 'location'): locations.append(condition_name) self._sampling_locations = locations # extract total population total_population = 0 for location in self._sampling_locations: pts = trainer._model.problem.input_pts[location] total_population += len(pts) self._tot_pop_numb = total_population def on_train_epoch_end(self, trainer, __): if trainer.current_epoch % self._sample_every == 0: self._r3_routine(trainer)