check train test (#111)

* Update setup.py
* scheduler for all torch versions
---------

Co-authored-by: Dario Coscia <dariocoscia@cli-10-110-0-208.WIFIeduroamSTUD.units.it>
Co-authored-by: Nicola Demo <demo.nicola@gmail.com>
This commit is contained in:
Dario Coscia
2023-06-13 12:07:04 +02:00
committed by Nicola Demo
parent 8ce3405edc
commit 0fb93a73ab
3 changed files with 13 additions and 9 deletions

View File

@@ -1,7 +1,11 @@
""" Module for PINN """ """ Module for PINN """
import torch import torch
import torch.optim.lr_scheduler as lrs try:
from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0
except ImportError:
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler # torch < 2.0
from torch.optim.lr_scheduler import ConstantLR
from .solver import SolverInterface from .solver import SolverInterface
from .label_tensor import LabelTensor from .label_tensor import LabelTensor
@@ -23,7 +27,7 @@ class PINN(SolverInterface):
loss = torch.nn.MSELoss(), loss = torch.nn.MSELoss(),
optimizer=torch.optim.Adam, optimizer=torch.optim.Adam,
optimizer_kwargs={'lr' : 0.001}, optimizer_kwargs={'lr' : 0.001},
scheduler=lrs.ConstantLR, scheduler=ConstantLR,
scheduler_kwargs={"factor": 1, "total_iters": 0}, scheduler_kwargs={"factor": 1, "total_iters": 0},
): ):
''' '''
@@ -46,7 +50,7 @@ class PINN(SolverInterface):
# check consistency # check consistency
check_consistency(optimizer, torch.optim.Optimizer, 'optimizer', subclass=True) check_consistency(optimizer, torch.optim.Optimizer, 'optimizer', subclass=True)
check_consistency(optimizer_kwargs, dict, 'optimizer_kwargs') check_consistency(optimizer_kwargs, dict, 'optimizer_kwargs')
check_consistency(scheduler, lrs.LRScheduler, 'scheduler', subclass=True) check_consistency(scheduler, LRScheduler, 'scheduler', subclass=True)
check_consistency(scheduler_kwargs, dict, 'scheduler_kwargs') check_consistency(scheduler_kwargs, dict, 'scheduler_kwargs')
check_consistency(loss, (LossInterface, _Loss), 'loss', subclass=False) check_consistency(loss, (LossInterface, _Loss), 'loss', subclass=False)

View File

@@ -15,7 +15,7 @@ VERSION = meta['__version__']
KEYWORDS = 'physics-informed neural-network' KEYWORDS = 'physics-informed neural-network'
REQUIRED = [ REQUIRED = [
'numpy', 'matplotlib', 'torch' 'numpy', 'matplotlib', 'torch', 'lightning'
] ]
EXTRAS = { EXTRAS = {

View File

@@ -86,24 +86,24 @@ def test_constructor_extra_feats():
model_extra_feats = FeedForward(len(poisson_problem.input_variables)+1,len(poisson_problem.output_variables)) model_extra_feats = FeedForward(len(poisson_problem.input_variables)+1,len(poisson_problem.output_variables))
PINN(problem = poisson_problem, model=model_extra_feats, extra_features=extra_feats) PINN(problem = poisson_problem, model=model_extra_feats, extra_features=extra_feats)
def test_train(): def test_train_cpu():
poisson_problem = Poisson() poisson_problem = Poisson()
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
n = 10 n = 10
poisson_problem.discretise_domain(n, 'grid', locations=boundaries) poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
poisson_problem.discretise_domain(n, 'grid', locations=['D']) poisson_problem.discretise_domain(n, 'grid', locations=['D'])
pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss()) pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5}) trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'cpu'})
trainer.train() trainer.train()
def test_train_extra_feats(): def test_train_extra_feats_cpu():
poisson_problem = Poisson() poisson_problem = Poisson()
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
n = 10 n = 10
poisson_problem.discretise_domain(n, 'grid', locations=boundaries) poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
poisson_problem.discretise_domain(n, 'grid', locations=['D']) poisson_problem.discretise_domain(n, 'grid', locations=['D'])
pinn = PINN(problem = poisson_problem, model=model_extra_feats, extra_features=extra_feats) pinn = PINN(problem = poisson_problem, model=model_extra_feats, extra_features=extra_feats)
trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5}) trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'cpu'})
trainer.train() trainer.train()
""" """