minor fix

This commit is contained in:
Dario Coscia
2023-09-19 15:13:50 +02:00
committed by Nicola Demo
parent 4d1187898f
commit 1936133ad5
5 changed files with 222 additions and 57 deletions

View File

@@ -0,0 +1,62 @@
'''PINA Callbacks Implementations'''
from lightning.pytorch.callbacks import Callback
import torch
from ..utils import check_consistency
class SwitchOptimizer(Callback):
"""
PINA implementation of a Lightining Callback to switch
optimizer during training. The rouutine can be used to
try multiple optimizers during the training, without the
need to stop training.
"""
def __init__(self, new_optimizers, new_optimizers_kwargs, epoch_switch):
"""
SwitchOptimizer is a routine for switching optimizer during training.
:param torch.optim.Optimizer | list new_optimizers: The model optimizers to
switch to. It must be a list of :class:`torch.optim.Optimizer` or list of
:class:`torch.optim.Optimizer` for multiple model solvers.
:param dict| list new_optimizers: The model optimizers keyword arguments to
switch use. It must be a dict or list of dict for multiple optimizers.
:param int epoch_switch: Epoch for switching optimizer.
"""
super().__init__()
# check type consistency
check_consistency(new_optimizers, torch.optim.Optimizer, subclass=True)
check_consistency(new_optimizers_kwargs, dict)
check_consistency(epoch_switch, int)
if epoch_switch < 1:
raise ValueError('epoch_switch must be greater than one.')
if not isinstance(new_optimizers, list):
new_optimizers = [new_optimizers]
new_optimizers_kwargs = [new_optimizers_kwargs]
len_optimizer = len(new_optimizers)
len_optimizer_kwargs = len(new_optimizers_kwargs)
if len_optimizer_kwargs != len_optimizer:
raise ValueError('You must define one dictionary of keyword'
' arguments for each optimizers.'
f' Got {len_optimizer} optimizers, and'
f' {len_optimizer_kwargs} dicitionaries')
# save new optimizers
self._new_optimizers = new_optimizers
self._new_optimizers_kwargs = new_optimizers_kwargs
self._epoch_switch = epoch_switch
def on_train_epoch_start(self, trainer, __):
if trainer.current_epoch == self._epoch_switch:
optims = []
for idx, (optim, optim_kwargs) in enumerate(
zip(self._new_optimizers,
self._new_optimizers_kwargs)
):
optims.append(optim(trainer._model.models[idx].parameters(), **optim_kwargs))
trainer.optimizers = optims