enable lr_scheduler usage

Author: Francesco Andreuzzi, 2022-12-25 17:54:12 +01:00
Committed by: Nicola Demo
parent 18c375887c
commit 5c09ff626c
2 changed files with 46 additions and 1 deletion


@@ -1,5 +1,6 @@
 """ Module for PINN """
 import torch
+import torch.optim.lr_scheduler as lrs
 from .problem import AbstractProblem
 from .label_tensor import LabelTensor
@@ -15,7 +16,10 @@ class PINN(object):
         problem,
         model,
         optimizer=torch.optim.Adam,
+        optimizer_kwargs=None,
         lr=0.001,
+        lr_scheduler_type=lrs.ConstantLR,
+        lr_scheduler_kwargs={"factor" : 1, "total_iters" : 0},
         regularizer=0.00001,
         batch_size=None,
         dtype=torch.float32,
@@ -26,7 +30,10 @@ class PINN(object):
         :param torch.nn.Module model: the neural network model to use.
         :param torch.optim optimizer: the neural network optimizer to use;
             default is `torch.optim.Adam`.
+        :param dict optimizer_kwargs: optimizer constructor keyword arguments.
         :param float lr: the learning rate; default is 0.001.
+        :param torch.optim.lr_scheduler._LRScheduler lr_scheduler_type: the learning rate scheduler class.
+        :param dict lr_scheduler_kwargs: learning rate scheduler constructor keyword arguments.
         :param float regularizer: the coefficient for L2 regularizer term.
         :param type dtype: the data type to use for the model. Valid options are
             `torch.float32` and `torch.float64` (`torch.float16` only on GPU);
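The defaults above are chosen so that nothing changes for existing users: `ConstantLR` with `factor=1` and `total_iters=0` never modifies the learning rate. A minimal, standalone PyTorch sketch showing that this default scheduler is a no-op (the `torch.nn.Linear` model and optimizer are placeholders, not PINA code):

import torch
import torch.optim.lr_scheduler as lrs

# Placeholder model and optimizer, only to exercise the scheduler defaults.
model = torch.nn.Linear(2, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Same defaults as the new PINN arguments: factor=1 leaves lr untouched.
scheduler = lrs.ConstantLR(optimizer, factor=1, total_iters=0)

for _ in range(3):
    optimizer.step()   # in real training this follows a backward pass
    scheduler.step()
    print(scheduler.get_last_lr())  # stays [0.001] at every epoch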
@@ -67,8 +74,13 @@ class PINN(object):
         self.input_pts = {}
         self.trained_epoch = 0
+        if not optimizer_kwargs:
+            optimizer_kwargs = {}
+        optimizer_kwargs['lr'] = lr
         self.optimizer = optimizer(
-            self.model.parameters(), lr=lr, weight_decay=regularizer)
+            self.model.parameters(), weight_decay=regularizer, **optimizer_kwargs)
+        self._lr_scheduler = lr_scheduler_type(self.optimizer, **lr_scheduler_kwargs)
         self.batch_size = batch_size
         self.data_set = PinaDataset(self)
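Outside of PINA, the same wiring can be reproduced in plain PyTorch; the sketch below mirrors the constructor logic above (the helper name `build_optimizer_and_scheduler` is hypothetical, not part of the library). As in the commit, the explicit `lr` argument always overwrites any `lr` key the caller puts in `optimizer_kwargs`.

import torch
import torch.optim.lr_scheduler as lrs

def build_optimizer_and_scheduler(model, optimizer=torch.optim.Adam,
                                  optimizer_kwargs=None, lr=0.001,
                                  lr_scheduler_type=lrs.ConstantLR,
                                  lr_scheduler_kwargs={"factor": 1, "total_iters": 0},
                                  regularizer=0.00001):
    # Hypothetical helper mirroring the wiring in PINN.__init__ above.
    if not optimizer_kwargs:
        optimizer_kwargs = {}
    optimizer_kwargs['lr'] = lr  # explicit lr overrides any value in optimizer_kwargs
    opt = optimizer(model.parameters(), weight_decay=regularizer, **optimizer_kwargs)
    sched = lr_scheduler_type(opt, **lr_scheduler_kwargs)
    return opt, sched

# Example: exponential decay instead of the constant default.
opt, sched = build_optimizer_and_scheduler(
    torch.nn.Linear(2, 1),
    lr_scheduler_type=lrs.ExponentialLR,
    lr_scheduler_kwargs={"gamma": 0.9})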
@@ -268,6 +280,8 @@ class PINN(object):
                 losses.append(sum(single_loss))
+            self._lr_scheduler.step()
             if save_loss and (epoch % save_loss == 0 or epoch == 0):
                 self.history_loss[epoch] = [
                     loss.detach().item() for loss in losses]
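With the `step()` call in place, a non-default scheduler actually changes the learning rate once per epoch. A self-contained PyTorch sketch of the same per-epoch pattern (toy model and loss, not PINA code):

import torch
import torch.optim.lr_scheduler as lrs

model = torch.nn.Linear(2, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = lrs.ExponentialLR(optimizer, gamma=0.5)

for epoch in range(3):
    optimizer.zero_grad()
    loss = model(torch.rand(4, 2)).pow(2).mean()  # toy loss
    loss.backward()
    optimizer.step()
    scheduler.step()                       # once per epoch, as in PINN.train above
    print(epoch, scheduler.get_last_lr())  # [0.0005], [0.00025], [0.000125]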