From 5c09ff626c73398b871d29034a650938a03a69de Mon Sep 17 00:00:00 2001 From: Francesco Andreuzzi Date: Sun, 25 Dec 2022 17:54:12 +0100 Subject: [PATCH] enable lr_scheduler usage --- pina/pinn.py | 16 +++++++++++++++- tests/test_pinn.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/pina/pinn.py b/pina/pinn.py index afd285a..4ae9628 100644 --- a/pina/pinn.py +++ b/pina/pinn.py @@ -1,5 +1,6 @@ """ Module for PINN """ import torch +import torch.optim.lr_scheduler as lrs from .problem import AbstractProblem from .label_tensor import LabelTensor @@ -15,7 +16,10 @@ class PINN(object): problem, model, optimizer=torch.optim.Adam, + optimizer_kwargs=None, lr=0.001, + lr_scheduler_type=lrs.ConstantLR, + lr_scheduler_kwargs={"factor" : 1, "total_iters" : 0}, regularizer=0.00001, batch_size=None, dtype=torch.float32, @@ -26,7 +30,10 @@ class PINN(object): :param torch.nn.Module model: the neural network model to use. :param torch.optim optimizer: the neural network optimizer to use; default is `torch.optim.Adam`. + :param dict optimizer_kwargs: Optimizer constructor keyword args. :param float lr: the learning rate; default is 0.001. + :param torch.optim.lr_scheduler._LRScheduler lr_scheduler_type: Learning rate scheduler. + :param dict lr_scheduler_kwargs: LR scheduler constructor keyword args. :param float regularizer: the coefficient for L2 regularizer term. :param type dtype: the data type to use for the model. Valid option are `torch.float32` and `torch.float64` (`torch.float16` only on GPU); @@ -67,8 +74,13 @@ class PINN(object): self.input_pts = {} self.trained_epoch = 0 + + if not optimizer_kwargs: + optimizer_kwargs = {} + optimizer_kwargs['lr'] = lr self.optimizer = optimizer( - self.model.parameters(), lr=lr, weight_decay=regularizer) + self.model.parameters(), weight_decay=regularizer, **optimizer_kwargs) + self._lr_scheduler = lr_scheduler_type(self.optimizer, **lr_scheduler_kwargs) self.batch_size = batch_size self.data_set = PinaDataset(self) @@ -268,6 +280,8 @@ class PINN(object): losses.append(sum(single_loss)) + self._lr_scheduler.step() + if save_loss and (epoch % save_loss == 0 or epoch == 0): self.history_loss[epoch] = [ loss.detach().item() for loss in losses] diff --git a/tests/test_pinn.py b/tests/test_pinn.py index 7c0cd0b..5f8d7c7 100644 --- a/tests/test_pinn.py +++ b/tests/test_pinn.py @@ -118,6 +118,37 @@ def test_train(): assert list(pinn.history_loss.keys()) == truth_key +def test_train_with_optimizer_kwargs(): + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] + n = 10 + expected_keys = [[], list(range(0, 50, 3))] + param = [0, 3] + for i, truth_key in zip(param, expected_keys): + pinn = PINN(problem, model, optimizer_kwargs={'lr' : 0.3}) + pinn.span_pts(n, 'grid', boundaries) + pinn.span_pts(n, 'grid', locations=['D']) + pinn.train(50, save_loss=i) + assert list(pinn.history_loss.keys()) == truth_key + + +def test_train_with_lr_scheduler(): + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] + n = 10 + expected_keys = [[], list(range(0, 50, 3))] + param = [0, 3] + for i, truth_key in zip(param, expected_keys): + pinn = PINN( + problem, + model, + lr_scheduler_type=torch.optim.lr_scheduler.CyclicLR, + lr_scheduler_kwargs={'base_lr' : 0.1, 'max_lr' : 0.3, 'cycle_momentum': False} + ) + pinn.span_pts(n, 'grid', boundaries) + pinn.span_pts(n, 'grid', locations=['D']) + pinn.train(50, save_loss=i) + assert list(pinn.history_loss.keys()) == truth_key + + def test_train_batch(): pinn = PINN(problem, model, batch_size=6) boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']