enable lr_scheduler usage
committed by Nicola Demo
parent 18c375887c
commit 5c09ff626c
pina/pinn.py (16 changed lines)
@@ -1,5 +1,6 @@
 """ Module for PINN """
 import torch
+import torch.optim.lr_scheduler as lrs

 from .problem import AbstractProblem
 from .label_tensor import LabelTensor
@@ -15,7 +16,10 @@ class PINN(object):
         problem,
         model,
         optimizer=torch.optim.Adam,
+        optimizer_kwargs=None,
         lr=0.001,
+        lr_scheduler_type=lrs.ConstantLR,
+        lr_scheduler_kwargs={"factor" : 1, "total_iters" : 0},
         regularizer=0.00001,
         batch_size=None,
         dtype=torch.float32,
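The new defaults keep existing behaviour unchanged: `lrs.ConstantLR` with `factor=1` and `total_iters=0` never rescales the learning rate. A minimal sketch in plain PyTorch (the `Linear` model is only a stand-in for the PINN network) that checks the default scheduler is effectively a no-op:

    import torch
    import torch.optim.lr_scheduler as lrs

    net = torch.nn.Linear(2, 1)   # toy stand-in for the PINN model
    opt = torch.optim.Adam(net.parameters(), lr=0.001)
    sched = lrs.ConstantLR(opt, factor=1, total_iters=0)
    print(sched.get_last_lr())    # [0.001] -- the rate is never rescaled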
@@ -26,7 +30,10 @@ class PINN(object):
         :param torch.nn.Module model: the neural network model to use.
         :param torch.optim optimizer: the neural network optimizer to use;
             default is `torch.optim.Adam`.
+        :param dict optimizer_kwargs: Optimizer constructor keyword args.
         :param float lr: the learning rate; default is 0.001.
+        :param torch.optim.lr_scheduler._LRScheduler lr_scheduler_type: Learning rate scheduler.
+        :param dict lr_scheduler_kwargs: LR scheduler constructor keyword args.
         :param float regularizer: the coefficient for L2 regularizer term.
         :param type dtype: the data type to use for the model. Valid option are
             `torch.float32` and `torch.float64` (`torch.float16` only on GPU);
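Note that `lr_scheduler_type` is documented as a scheduler class (a `_LRScheduler` subclass), not an instance: PINN instantiates it internally with the optimizer and `lr_scheduler_kwargs`. A hedged illustration of the kind of pair a caller would hand over (the choice of `ExponentialLR` and `gamma` here is purely an example):

    import torch.optim.lr_scheduler as lrs

    scheduler_cls = lrs.ExponentialLR    # passed as the class itself
    scheduler_kwargs = {'gamma': 0.95}   # forwarded to its constructor inside PINN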
@@ -67,8 +74,13 @@ class PINN(object):
         self.input_pts = {}

         self.trained_epoch = 0
+
+        if not optimizer_kwargs:
+            optimizer_kwargs = {}
+        optimizer_kwargs['lr'] = lr
         self.optimizer = optimizer(
-            self.model.parameters(), lr=lr, weight_decay=regularizer)
+            self.model.parameters(), weight_decay=regularizer, **optimizer_kwargs)
+        self._lr_scheduler = lr_scheduler_type(self.optimizer, **lr_scheduler_kwargs)

         self.batch_size = batch_size
         self.data_set = PinaDataset(self)
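Extra optimizer options now travel through `optimizer_kwargs`, L2 regularization still comes from `regularizer` via `weight_decay`, and the scheduler is built on top of the resulting optimizer. A standalone sketch of the equivalent wiring (the model, the `betas` value, and the `StepLR` choice are illustrative only):

    import torch
    import torch.optim.lr_scheduler as lrs

    net = torch.nn.Linear(2, 1)
    optimizer_kwargs = {'betas': (0.9, 0.99)}        # extra Adam options a caller might pass
    optimizer_kwargs['lr'] = 0.001                   # base lr merged in, as in the diff above
    opt = torch.optim.Adam(net.parameters(), weight_decay=0.00001, **optimizer_kwargs)
    sched = lrs.StepLR(opt, step_size=10, gamma=0.5) # any scheduler type/kwargs can be swapped in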
@@ -268,6 +280,8 @@ class PINN(object):

                     losses.append(sum(single_loss))

+            self._lr_scheduler.step()
+
             if save_loss and (epoch % save_loss == 0 or epoch == 0):
                 self.history_loss[epoch] = [
                     loss.detach().item() for loss in losses]
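The scheduler is stepped once per epoch, after the epoch's losses have been accumulated, so the learning rate follows the chosen schedule over training. A short standalone sketch (not PINA code) of that effect with a decaying schedule:

    import torch
    import torch.optim.lr_scheduler as lrs

    net = torch.nn.Linear(2, 1)
    opt = torch.optim.Adam(net.parameters(), lr=0.1)
    sched = lrs.StepLR(opt, step_size=2, gamma=0.5)
    for epoch in range(6):
        opt.step()    # stand-in for one training epoch
        sched.step()  # mirrors self._lr_scheduler.step() above
    print(opt.param_groups[0]['lr'])   # 0.0125: halved every 2 epochs over 6 epochs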
@@ -118,6 +118,37 @@ def test_train():
     assert list(pinn.history_loss.keys()) == truth_key


+def test_train_with_optimizer_kwargs():
+    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
+    n = 10
+    expected_keys = [[], list(range(0, 50, 3))]
+    param = [0, 3]
+    for i, truth_key in zip(param, expected_keys):
+        pinn = PINN(problem, model, optimizer_kwargs={'lr' : 0.3})
+        pinn.span_pts(n, 'grid', boundaries)
+        pinn.span_pts(n, 'grid', locations=['D'])
+        pinn.train(50, save_loss=i)
+        assert list(pinn.history_loss.keys()) == truth_key
+
+
+def test_train_with_lr_scheduler():
+    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
+    n = 10
+    expected_keys = [[], list(range(0, 50, 3))]
+    param = [0, 3]
+    for i, truth_key in zip(param, expected_keys):
+        pinn = PINN(
+            problem,
+            model,
+            lr_scheduler_type=torch.optim.lr_scheduler.CyclicLR,
+            lr_scheduler_kwargs={'base_lr' : 0.1, 'max_lr' : 0.3, 'cycle_momentum': False}
+        )
+        pinn.span_pts(n, 'grid', boundaries)
+        pinn.span_pts(n, 'grid', locations=['D'])
+        pinn.train(50, save_loss=i)
+        assert list(pinn.history_loss.keys()) == truth_key
+
+
 def test_train_batch():
     pinn = PINN(problem, model, batch_size=6)
     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
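Taken together with the tests above, a user script can now opt into a learning-rate schedule purely through the constructor. A hedged usage sketch, assuming `problem` and `model` are set up as in the tests and that `PINN` is importable from the package root:

    import torch
    from pina import PINN   # import path assumed

    pinn = PINN(
        problem,
        model,
        lr=0.01,
        lr_scheduler_type=torch.optim.lr_scheduler.StepLR,
        lr_scheduler_kwargs={'step_size': 100, 'gamma': 0.5},
    )
    pinn.span_pts(10, 'grid', locations=['gamma1', 'gamma2', 'gamma3', 'gamma4'])
    pinn.span_pts(10, 'grid', locations=['D'])
    pinn.train(500, save_loss=50)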