check train test (#111)
* Update setup.py * scheduler for all torch versions --------- Co-authored-by: Dario Coscia <dariocoscia@cli-10-110-0-208.WIFIeduroamSTUD.units.it> Co-authored-by: Nicola Demo <demo.nicola@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
8ce3405edc
commit
0fb93a73ab
10
pina/pinn.py
10
pina/pinn.py
@@ -1,7 +1,11 @@
|
|||||||
""" Module for PINN """
|
""" Module for PINN """
|
||||||
import torch
|
import torch
|
||||||
import torch.optim.lr_scheduler as lrs
|
try:
|
||||||
|
from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0
|
||||||
|
except ImportError:
|
||||||
|
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler # torch < 2.0
|
||||||
|
|
||||||
|
from torch.optim.lr_scheduler import ConstantLR
|
||||||
|
|
||||||
from .solver import SolverInterface
|
from .solver import SolverInterface
|
||||||
from .label_tensor import LabelTensor
|
from .label_tensor import LabelTensor
|
||||||
@@ -23,7 +27,7 @@ class PINN(SolverInterface):
|
|||||||
loss = torch.nn.MSELoss(),
|
loss = torch.nn.MSELoss(),
|
||||||
optimizer=torch.optim.Adam,
|
optimizer=torch.optim.Adam,
|
||||||
optimizer_kwargs={'lr' : 0.001},
|
optimizer_kwargs={'lr' : 0.001},
|
||||||
scheduler=lrs.ConstantLR,
|
scheduler=ConstantLR,
|
||||||
scheduler_kwargs={"factor": 1, "total_iters": 0},
|
scheduler_kwargs={"factor": 1, "total_iters": 0},
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
@@ -46,7 +50,7 @@ class PINN(SolverInterface):
|
|||||||
# check consistency
|
# check consistency
|
||||||
check_consistency(optimizer, torch.optim.Optimizer, 'optimizer', subclass=True)
|
check_consistency(optimizer, torch.optim.Optimizer, 'optimizer', subclass=True)
|
||||||
check_consistency(optimizer_kwargs, dict, 'optimizer_kwargs')
|
check_consistency(optimizer_kwargs, dict, 'optimizer_kwargs')
|
||||||
check_consistency(scheduler, lrs.LRScheduler, 'scheduler', subclass=True)
|
check_consistency(scheduler, LRScheduler, 'scheduler', subclass=True)
|
||||||
check_consistency(scheduler_kwargs, dict, 'scheduler_kwargs')
|
check_consistency(scheduler_kwargs, dict, 'scheduler_kwargs')
|
||||||
check_consistency(loss, (LossInterface, _Loss), 'loss', subclass=False)
|
check_consistency(loss, (LossInterface, _Loss), 'loss', subclass=False)
|
||||||
|
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -15,7 +15,7 @@ VERSION = meta['__version__']
|
|||||||
KEYWORDS = 'physics-informed neural-network'
|
KEYWORDS = 'physics-informed neural-network'
|
||||||
|
|
||||||
REQUIRED = [
|
REQUIRED = [
|
||||||
'numpy', 'matplotlib', 'torch'
|
'numpy', 'matplotlib', 'torch', 'lightning'
|
||||||
]
|
]
|
||||||
|
|
||||||
EXTRAS = {
|
EXTRAS = {
|
||||||
|
|||||||
@@ -86,24 +86,24 @@ def test_constructor_extra_feats():
|
|||||||
model_extra_feats = FeedForward(len(poisson_problem.input_variables)+1,len(poisson_problem.output_variables))
|
model_extra_feats = FeedForward(len(poisson_problem.input_variables)+1,len(poisson_problem.output_variables))
|
||||||
PINN(problem = poisson_problem, model=model_extra_feats, extra_features=extra_feats)
|
PINN(problem = poisson_problem, model=model_extra_feats, extra_features=extra_feats)
|
||||||
|
|
||||||
def test_train():
|
def test_train_cpu():
|
||||||
poisson_problem = Poisson()
|
poisson_problem = Poisson()
|
||||||
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||||
n = 10
|
n = 10
|
||||||
poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
|
poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
|
||||||
poisson_problem.discretise_domain(n, 'grid', locations=['D'])
|
poisson_problem.discretise_domain(n, 'grid', locations=['D'])
|
||||||
pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
|
pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
|
||||||
trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5})
|
trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'cpu'})
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
def test_train_extra_feats():
|
def test_train_extra_feats_cpu():
|
||||||
poisson_problem = Poisson()
|
poisson_problem = Poisson()
|
||||||
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||||
n = 10
|
n = 10
|
||||||
poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
|
poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
|
||||||
poisson_problem.discretise_domain(n, 'grid', locations=['D'])
|
poisson_problem.discretise_domain(n, 'grid', locations=['D'])
|
||||||
pinn = PINN(problem = poisson_problem, model=model_extra_feats, extra_features=extra_feats)
|
pinn = PINN(problem = poisson_problem, model=model_extra_feats, extra_features=extra_feats)
|
||||||
trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5})
|
trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'cpu'})
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user