From 1ed14916f1309633c6ac39bd257398aa52fd3829 Mon Sep 17 00:00:00 2001 From: Giovanni Canali Date: Wed, 23 Jul 2025 17:42:25 +0200 Subject: [PATCH] fix switch_optimizer callback --- pina/callback/optimizer_callback.py | 31 +++++--- .../test_callback/test_optimizer_callback.py | 74 ++++++++++++------- 2 files changed, 65 insertions(+), 40 deletions(-) diff --git a/pina/callback/optimizer_callback.py b/pina/callback/optimizer_callback.py index fb2770a..1b51840 100644 --- a/pina/callback/optimizer_callback.py +++ b/pina/callback/optimizer_callback.py @@ -21,26 +21,30 @@ class SwitchOptimizer(Callback): single :class:`torch.optim.Optimizer` instance or a list of them for multiple model solver. :type new_optimizers: pina.optim.TorchOptimizer | list - :param epoch_switch: The epoch at which the optimizer switch occurs. - :type epoch_switch: int + :param int epoch_switch: The epoch at which the optimizer switch occurs. Example: - >>> switch_callback = SwitchOptimizer(new_optimizers=optimizer, - >>> epoch_switch=10) + >>> optimizer = TorchOptimizer(torch.optim.Adam, lr=0.01) + >>> switch_callback = SwitchOptimizer( + >>> new_optimizers=optimizer, epoch_switch=10 + >>> ) """ super().__init__() + # Check if epoch_switch is greater than 1 if epoch_switch < 1: raise ValueError("epoch_switch must be greater than one.") + # If new_optimizers is not a list, convert it to a list if not isinstance(new_optimizers, list): new_optimizers = [new_optimizers] - # check type consistency + # Check consistency + check_consistency(epoch_switch, int) for optimizer in new_optimizers: check_consistency(optimizer, TorchOptimizer) - check_consistency(epoch_switch, int) - # save new optimizers + + # Store the new optimizers and epoch switch self._new_optimizers = new_optimizers self._epoch_switch = epoch_switch @@ -48,18 +52,21 @@ class SwitchOptimizer(Callback): """ Switch the optimizer at the start of the specified training epoch. - :param trainer: The trainer object managing the training process. - :type trainer: pytorch_lightning.Trainer + :param lightning.pytorch.Trainer trainer: The trainer object managing + the training process. :param _: Placeholder argument (not used). - - :return: None - :rtype: None """ + # Check if the current epoch matches the switch epoch if trainer.current_epoch == self._epoch_switch: optims = [] + # Hook the new optimizers to the model parameters for idx, optim in enumerate(self._new_optimizers): optim.hook(trainer.solver._pina_models[idx].parameters()) optims.append(optim) + # Update the solver's optimizers trainer.solver._pina_optimizers = optims + + # Update the trainer's strategy optimizers + trainer.strategy.optimizers = [o.instance for o in optims] diff --git a/tests/test_callback/test_optimizer_callback.py b/tests/test_callback/test_optimizer_callback.py index 785a9c3..3383c79 100644 --- a/tests/test_callback/test_optimizer_callback.py +++ b/tests/test_callback/test_optimizer_callback.py @@ -1,45 +1,63 @@ -from pina.callback import SwitchOptimizer import torch import pytest from pina.solver import PINN from pina.trainer import Trainer from pina.model import FeedForward -from pina.problem.zoo import Poisson2DSquareProblem as Poisson from pina.optim import TorchOptimizer - -# make the problem -poisson_problem = Poisson() -boundaries = ["g1", "g2", "g3", "g4"] -n = 10 -poisson_problem.discretise_domain(n, "grid", domains=boundaries) -poisson_problem.discretise_domain(n, "grid", domains="D") -model = FeedForward( - len(poisson_problem.input_variables), len(poisson_problem.output_variables) -) - -# make the solver -solver = PINN(problem=poisson_problem, model=model) - -adam = TorchOptimizer(torch.optim.Adam, lr=0.01) -lbfgs = TorchOptimizer(torch.optim.LBFGS, lr=0.001) +from pina.callback import SwitchOptimizer +from pina.problem.zoo import Poisson2DSquareProblem as Poisson -def test_switch_optimizer_constructor(): - SwitchOptimizer(adam, epoch_switch=10) +# Define the problem +problem = Poisson() +problem.discretise_domain(10) +model = FeedForward(len(problem.input_variables), len(problem.output_variables)) + +# Define the optimizer +optimizer = TorchOptimizer(torch.optim.Adam) + +# Initialize the solver +solver = PINN(problem=problem, model=model, optimizer=optimizer) + +# Define new optimizers for testing +lbfgs = TorchOptimizer(torch.optim.LBFGS, lr=1.0) +adamW = TorchOptimizer(torch.optim.AdamW, lr=0.01) -def test_switch_optimizer_routine(): - # check initial optimizer +@pytest.mark.parametrize("epoch_switch", [5, 10]) +@pytest.mark.parametrize("new_opt", [lbfgs, adamW]) +def test_switch_optimizer_constructor(new_opt, epoch_switch): + + # Constructor + SwitchOptimizer(new_optimizers=new_opt, epoch_switch=epoch_switch) + + # Should fail if epoch_switch is less than 1 + with pytest.raises(ValueError): + SwitchOptimizer(new_optimizers=new_opt, epoch_switch=0) + + +@pytest.mark.parametrize("epoch_switch", [5, 10]) +@pytest.mark.parametrize("new_opt", [lbfgs, adamW]) +def test_switch_optimizer_routine(new_opt, epoch_switch): + + # Check if the optimizer is initialized correctly solver.configure_optimizers() - assert solver.optimizer.instance.__class__ == torch.optim.Adam - # make the trainer - switch_opt_callback = SwitchOptimizer(lbfgs, epoch_switch=3) + + # Initialize the trainer + switch_opt_callback = SwitchOptimizer( + new_optimizers=new_opt, epoch_switch=epoch_switch + ) trainer = Trainer( solver=solver, - callbacks=[switch_opt_callback], + callbacks=switch_opt_callback, accelerator="cpu", - max_epochs=5, + max_epochs=epoch_switch + 2, ) trainer.train() - assert solver.optimizer.instance.__class__ == torch.optim.LBFGS + + # Check that the trainer strategy optimizers have been updated + assert solver.optimizer.instance.__class__ == new_opt.instance.__class__ + assert ( + trainer.strategy.optimizers[0].__class__ == new_opt.instance.__class__ + )