fix switch_optimizer callback

This commit is contained in:
Giovanni Canali
2025-07-23 17:42:25 +02:00
committed by Giovanni Canali
parent 6d10989d89
commit 1ed14916f1
2 changed files with 65 additions and 40 deletions

View File

@@ -21,26 +21,30 @@ class SwitchOptimizer(Callback):
single :class:`torch.optim.Optimizer` instance or a list of them single :class:`torch.optim.Optimizer` instance or a list of them
for multiple model solver. for multiple model solver.
:type new_optimizers: pina.optim.TorchOptimizer | list :type new_optimizers: pina.optim.TorchOptimizer | list
:param epoch_switch: The epoch at which the optimizer switch occurs. :param int epoch_switch: The epoch at which the optimizer switch occurs.
:type epoch_switch: int
Example: Example:
>>> switch_callback = SwitchOptimizer(new_optimizers=optimizer, >>> optimizer = TorchOptimizer(torch.optim.Adam, lr=0.01)
>>> epoch_switch=10) >>> switch_callback = SwitchOptimizer(
>>> new_optimizers=optimizer, epoch_switch=10
>>> )
""" """
super().__init__() super().__init__()
# Check if epoch_switch is greater than 1
if epoch_switch < 1: if epoch_switch < 1:
raise ValueError("epoch_switch must be greater than one.") raise ValueError("epoch_switch must be greater than one.")
# If new_optimizers is not a list, convert it to a list
if not isinstance(new_optimizers, list): if not isinstance(new_optimizers, list):
new_optimizers = [new_optimizers] new_optimizers = [new_optimizers]
# check type consistency # Check consistency
check_consistency(epoch_switch, int)
for optimizer in new_optimizers: for optimizer in new_optimizers:
check_consistency(optimizer, TorchOptimizer) check_consistency(optimizer, TorchOptimizer)
check_consistency(epoch_switch, int)
# save new optimizers # Store the new optimizers and epoch switch
self._new_optimizers = new_optimizers self._new_optimizers = new_optimizers
self._epoch_switch = epoch_switch self._epoch_switch = epoch_switch
@@ -48,18 +52,21 @@ class SwitchOptimizer(Callback):
""" """
Switch the optimizer at the start of the specified training epoch. Switch the optimizer at the start of the specified training epoch.
:param trainer: The trainer object managing the training process. :param lightning.pytorch.Trainer trainer: The trainer object managing
:type trainer: pytorch_lightning.Trainer the training process.
:param _: Placeholder argument (not used). :param _: Placeholder argument (not used).
:return: None
:rtype: None
""" """
# Check if the current epoch matches the switch epoch
if trainer.current_epoch == self._epoch_switch: if trainer.current_epoch == self._epoch_switch:
optims = [] optims = []
# Hook the new optimizers to the model parameters
for idx, optim in enumerate(self._new_optimizers): for idx, optim in enumerate(self._new_optimizers):
optim.hook(trainer.solver._pina_models[idx].parameters()) optim.hook(trainer.solver._pina_models[idx].parameters())
optims.append(optim) optims.append(optim)
# Update the solver's optimizers
trainer.solver._pina_optimizers = optims trainer.solver._pina_optimizers = optims
# Update the trainer's strategy optimizers
trainer.strategy.optimizers = [o.instance for o in optims]

View File

@@ -1,45 +1,63 @@
from pina.callback import SwitchOptimizer
import torch import torch
import pytest import pytest
from pina.solver import PINN from pina.solver import PINN
from pina.trainer import Trainer from pina.trainer import Trainer
from pina.model import FeedForward from pina.model import FeedForward
from pina.problem.zoo import Poisson2DSquareProblem as Poisson
from pina.optim import TorchOptimizer from pina.optim import TorchOptimizer
from pina.callback import SwitchOptimizer
# make the problem from pina.problem.zoo import Poisson2DSquareProblem as Poisson
poisson_problem = Poisson()
boundaries = ["g1", "g2", "g3", "g4"]
n = 10
poisson_problem.discretise_domain(n, "grid", domains=boundaries)
poisson_problem.discretise_domain(n, "grid", domains="D")
model = FeedForward(
len(poisson_problem.input_variables), len(poisson_problem.output_variables)
)
# make the solver
solver = PINN(problem=poisson_problem, model=model)
adam = TorchOptimizer(torch.optim.Adam, lr=0.01)
lbfgs = TorchOptimizer(torch.optim.LBFGS, lr=0.001)
def test_switch_optimizer_constructor(): # Define the problem
SwitchOptimizer(adam, epoch_switch=10) problem = Poisson()
problem.discretise_domain(10)
model = FeedForward(len(problem.input_variables), len(problem.output_variables))
# Define the optimizer
optimizer = TorchOptimizer(torch.optim.Adam)
# Initialize the solver
solver = PINN(problem=problem, model=model, optimizer=optimizer)
# Define new optimizers for testing
lbfgs = TorchOptimizer(torch.optim.LBFGS, lr=1.0)
adamW = TorchOptimizer(torch.optim.AdamW, lr=0.01)
def test_switch_optimizer_routine(): @pytest.mark.parametrize("epoch_switch", [5, 10])
# check initial optimizer @pytest.mark.parametrize("new_opt", [lbfgs, adamW])
def test_switch_optimizer_constructor(new_opt, epoch_switch):
# Constructor
SwitchOptimizer(new_optimizers=new_opt, epoch_switch=epoch_switch)
# Should fail if epoch_switch is less than 1
with pytest.raises(ValueError):
SwitchOptimizer(new_optimizers=new_opt, epoch_switch=0)
@pytest.mark.parametrize("epoch_switch", [5, 10])
@pytest.mark.parametrize("new_opt", [lbfgs, adamW])
def test_switch_optimizer_routine(new_opt, epoch_switch):
# Check if the optimizer is initialized correctly
solver.configure_optimizers() solver.configure_optimizers()
assert solver.optimizer.instance.__class__ == torch.optim.Adam
# make the trainer # Initialize the trainer
switch_opt_callback = SwitchOptimizer(lbfgs, epoch_switch=3) switch_opt_callback = SwitchOptimizer(
new_optimizers=new_opt, epoch_switch=epoch_switch
)
trainer = Trainer( trainer = Trainer(
solver=solver, solver=solver,
callbacks=[switch_opt_callback], callbacks=switch_opt_callback,
accelerator="cpu", accelerator="cpu",
max_epochs=5, max_epochs=epoch_switch + 2,
) )
trainer.train() trainer.train()
assert solver.optimizer.instance.__class__ == torch.optim.LBFGS
# Check that the trainer strategy optimizers have been updated
assert solver.optimizer.instance.__class__ == new_opt.instance.__class__
assert (
trainer.strategy.optimizers[0].__class__ == new_opt.instance.__class__
)