diff --git a/pina/pinn.py b/pina/pinn.py index 62c647b..86df135 100644 --- a/pina/pinn.py +++ b/pina/pinn.py @@ -1,7 +1,11 @@ """ Module for PINN """ import torch -import torch.optim.lr_scheduler as lrs +try: + from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0 +except ImportError: + from torch.optim.lr_scheduler import _LRScheduler as LRScheduler # torch < 2.0 +from torch.optim.lr_scheduler import ConstantLR from .solver import SolverInterface from .label_tensor import LabelTensor @@ -23,7 +27,7 @@ class PINN(SolverInterface): loss = torch.nn.MSELoss(), optimizer=torch.optim.Adam, optimizer_kwargs={'lr' : 0.001}, - scheduler=lrs.ConstantLR, + scheduler=ConstantLR, scheduler_kwargs={"factor": 1, "total_iters": 0}, ): ''' @@ -46,7 +50,7 @@ class PINN(SolverInterface): # check consistency check_consistency(optimizer, torch.optim.Optimizer, 'optimizer', subclass=True) check_consistency(optimizer_kwargs, dict, 'optimizer_kwargs') - check_consistency(scheduler, lrs.LRScheduler, 'scheduler', subclass=True) + check_consistency(scheduler, LRScheduler, 'scheduler', subclass=True) check_consistency(scheduler_kwargs, dict, 'scheduler_kwargs') check_consistency(loss, (LossInterface, _Loss), 'loss', subclass=False) @@ -116,4 +120,4 @@ class PINN(SolverInterface): # TODO Fix the bug, tot_loss is a label tensor without labels # we need to pass it as a torch tensor to make everything work total_loss = sum(condition_losses) - return total_loss \ No newline at end of file + return total_loss diff --git a/setup.py b/setup.py index 36038b0..ab9e1a6 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ VERSION = meta['__version__'] KEYWORDS = 'physics-informed neural-network' REQUIRED = [ - 'numpy', 'matplotlib', 'torch' + 'numpy', 'matplotlib', 'torch', 'lightning' ] EXTRAS = { diff --git a/tests/test_pinn.py b/tests/test_pinn.py index e04aa58..4e533f6 100644 --- a/tests/test_pinn.py +++ b/tests/test_pinn.py @@ -86,24 +86,24 @@ def test_constructor_extra_feats(): model_extra_feats = FeedForward(len(poisson_problem.input_variables)+1,len(poisson_problem.output_variables)) PINN(problem = poisson_problem, model=model_extra_feats, extra_features=extra_feats) -def test_train(): +def test_train_cpu(): poisson_problem = Poisson() boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] n = 10 poisson_problem.discretise_domain(n, 'grid', locations=boundaries) poisson_problem.discretise_domain(n, 'grid', locations=['D']) pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss()) - trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5}) + trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'cpu'}) trainer.train() -def test_train_extra_feats(): +def test_train_extra_feats_cpu(): poisson_problem = Poisson() boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] n = 10 poisson_problem.discretise_domain(n, 'grid', locations=boundaries) poisson_problem.discretise_domain(n, 'grid', locations=['D']) pinn = PINN(problem = poisson_problem, model=model_extra_feats, extra_features=extra_feats) - trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5}) + trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'cpu'}) trainer.train() """