Update callbacks and tests (#482)

---------

Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
This commit is contained in:
Dario Coscia
2025-03-13 16:19:38 +01:00
committed by Nicola Demo
parent 6ae301622b
commit 632934f9cc
8 changed files with 264 additions and 229 deletions

View File

@@ -1,27 +1,27 @@
"""PINA Callbacks Implementations"""
from lightning.pytorch.callbacks import Callback
import torch
from ..optim import TorchOptimizer
from ..utils import check_consistency
from pina.optim import TorchOptimizer
class SwitchOptimizer(Callback):
"""
PINA Implementation of a Lightning Callback to switch optimizer during
training.
"""
def __init__(self, new_optimizers, epoch_switch):
"""
PINA Implementation of a Lightning Callback to switch optimizer during
training.
This callback allows for switching between different optimizers during
This callback allows switching between different optimizers during
training, enabling the exploration of multiple optimization strategies
without the need to stop training.
without interrupting the training process.
:param new_optimizers: The model optimizers to switch to. Can be a
single :class:`torch.optim.Optimizer` or a list of them for multiple
model solver.
single :class:`torch.optim.Optimizer` instance or a list of them
for multiple model solver.
:type new_optimizers: pina.optim.TorchOptimizer | list
:param epoch_switch: The epoch at which to switch to the new optimizer.
:param epoch_switch: The epoch at which the optimizer switch occurs.
:type epoch_switch: int
Example:
@@ -46,7 +46,7 @@ class SwitchOptimizer(Callback):
def on_train_epoch_start(self, trainer, __):
"""
Callback function to switch optimizer at the start of each training epoch.
Switch the optimizer at the start of the specified training epoch.
:param trainer: The trainer object managing the training process.
:type trainer: pytorch_lightning.Trainer
@@ -59,7 +59,7 @@ class SwitchOptimizer(Callback):
optims = []
for idx, optim in enumerate(self._new_optimizers):
optim.hook(trainer.solver.models[idx].parameters())
optims.append(optim.instance)
optim.hook(trainer.solver._pina_models[idx].parameters())
optims.append(optim)
trainer.optimizers = optims
trainer.solver._pina_optimizers = optims