enable lr_scheduler usage

Francesco Andreuzzi
2022-12-25 17:54:12 +01:00
committed by Nicola Demo
parent 18c375887c
commit 5c09ff626c
2 changed files with 46 additions and 1 deletion


@@ -118,6 +118,37 @@ def test_train():
    assert list(pinn.history_loss.keys()) == truth_key


def test_train_with_optimizer_kwargs():
    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
    n = 10
    expected_keys = [[], list(range(0, 50, 3))]
    param = [0, 3]
    for i, truth_key in zip(param, expected_keys):
        pinn = PINN(problem, model, optimizer_kwargs={'lr' : 0.3})
        pinn.span_pts(n, 'grid', boundaries)
        pinn.span_pts(n, 'grid', locations=['D'])
        pinn.train(50, save_loss=i)
        assert list(pinn.history_loss.keys()) == truth_key


def test_train_with_lr_scheduler():
    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
    n = 10
    expected_keys = [[], list(range(0, 50, 3))]
    param = [0, 3]
    for i, truth_key in zip(param, expected_keys):
        pinn = PINN(
            problem,
            model,
            lr_scheduler_type=torch.optim.lr_scheduler.CyclicLR,
            lr_scheduler_kwargs={'base_lr' : 0.1, 'max_lr' : 0.3, 'cycle_momentum': False}
        )
        pinn.span_pts(n, 'grid', boundaries)
        pinn.span_pts(n, 'grid', locations=['D'])
        pinn.train(50, save_loss=i)
        assert list(pinn.history_loss.keys()) == truth_key


def test_train_batch():
    pinn = PINN(problem, model, batch_size=6)
    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
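
As the new test shows, the scheduler is configured through two constructor arguments: lr_scheduler_type, which takes a torch.optim.lr_scheduler class, and lr_scheduler_kwargs, the keyword arguments forwarded to that class. A minimal usage sketch outside the test harness follows; it assumes the usual from pina import PINN import, assumes problem and model are set up as in the test module, and swaps in StepLR purely to illustrate that other scheduler classes can be plugged in.

# Minimal sketch (not part of this diff): training with a decaying learning
# rate via the new scheduler arguments. `problem`, `model`, and the import
# path are assumptions taken from the surrounding test module, not verified.
import torch
from pina import PINN

pinn = PINN(
    problem,
    model,
    lr_scheduler_type=torch.optim.lr_scheduler.StepLR,    # scheduler class, not an instance
    lr_scheduler_kwargs={'step_size': 10, 'gamma': 0.5},  # forwarded to the scheduler constructor
)
pinn.span_pts(10, 'grid', ['gamma1', 'gamma2', 'gamma3', 'gamma4'])
pinn.span_pts(10, 'grid', locations=['D'])
pinn.train(50)

Passing the class rather than an instance is presumably what lets PINN build the scheduler around its own internally constructed optimizer, which is why the constructor arguments are supplied separately through lr_scheduler_kwargs.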