Update callbacks and tests (#482)
--------- Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
This commit is contained in:
committed by
FilippoOlivo
parent
18d178ab3a
commit
9dab6380f8
@@ -21,19 +21,25 @@ model = FeedForward(
|
||||
# make the solver
|
||||
solver = PINN(problem=poisson_problem, model=model)
|
||||
|
||||
adam_optimizer = TorchOptimizer(torch.optim.Adam, lr=0.01)
|
||||
lbfgs_optimizer = TorchOptimizer(torch.optim.LBFGS, lr=0.001)
|
||||
adam = TorchOptimizer(torch.optim.Adam, lr=0.01)
|
||||
lbfgs = TorchOptimizer(torch.optim.LBFGS, lr=0.001)
|
||||
|
||||
|
||||
def test_switch_optimizer_constructor():
|
||||
SwitchOptimizer(adam_optimizer, epoch_switch=10)
|
||||
SwitchOptimizer(adam, epoch_switch=10)
|
||||
|
||||
|
||||
# def test_switch_optimizer_routine(): #TODO revert
|
||||
# # make the trainer
|
||||
# switch_opt_callback = SwitchOptimizer(lbfgs_optimizer, epoch_switch=3)
|
||||
# trainer = Trainer(solver=solver,
|
||||
# callback=[switch_opt_callback],
|
||||
# accelerator='cpu',
|
||||
# max_epochs=5)
|
||||
# trainer.train()
|
||||
def test_switch_optimizer_routine():
|
||||
# check initial optimizer
|
||||
solver.configure_optimizers()
|
||||
assert solver.optimizer.instance.__class__ == torch.optim.Adam
|
||||
# make the trainer
|
||||
switch_opt_callback = SwitchOptimizer(lbfgs, epoch_switch=3)
|
||||
trainer = Trainer(
|
||||
solver=solver,
|
||||
callbacks=[switch_opt_callback],
|
||||
accelerator="cpu",
|
||||
max_epochs=5,
|
||||
)
|
||||
trainer.train()
|
||||
assert solver.optimizer.instance.__class__ == torch.optim.LBFGS
|
||||
|
||||
Reference in New Issue
Block a user