check train test (#111)
* Update setup.py * scheduler for all torch versions --------- Co-authored-by: Dario Coscia <dariocoscia@cli-10-110-0-208.WIFIeduroamSTUD.units.it> Co-authored-by: Nicola Demo <demo.nicola@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
8ce3405edc
commit
0fb93a73ab
12
pina/pinn.py
12
pina/pinn.py
@@ -1,7 +1,11 @@
|
||||
""" Module for PINN """
|
||||
import torch
|
||||
import torch.optim.lr_scheduler as lrs
|
||||
try:
|
||||
from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0
|
||||
except ImportError:
|
||||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler # torch < 2.0
|
||||
|
||||
from torch.optim.lr_scheduler import ConstantLR
|
||||
|
||||
from .solver import SolverInterface
|
||||
from .label_tensor import LabelTensor
|
||||
@@ -23,7 +27,7 @@ class PINN(SolverInterface):
|
||||
loss = torch.nn.MSELoss(),
|
||||
optimizer=torch.optim.Adam,
|
||||
optimizer_kwargs={'lr' : 0.001},
|
||||
scheduler=lrs.ConstantLR,
|
||||
scheduler=ConstantLR,
|
||||
scheduler_kwargs={"factor": 1, "total_iters": 0},
|
||||
):
|
||||
'''
|
||||
@@ -46,7 +50,7 @@ class PINN(SolverInterface):
|
||||
# check consistency
|
||||
check_consistency(optimizer, torch.optim.Optimizer, 'optimizer', subclass=True)
|
||||
check_consistency(optimizer_kwargs, dict, 'optimizer_kwargs')
|
||||
check_consistency(scheduler, lrs.LRScheduler, 'scheduler', subclass=True)
|
||||
check_consistency(scheduler, LRScheduler, 'scheduler', subclass=True)
|
||||
check_consistency(scheduler_kwargs, dict, 'scheduler_kwargs')
|
||||
check_consistency(loss, (LossInterface, _Loss), 'loss', subclass=False)
|
||||
|
||||
@@ -116,4 +120,4 @@ class PINN(SolverInterface):
|
||||
# TODO Fix the bug, tot_loss is a label tensor without labels
|
||||
# we need to pass it as a torch tensor to make everything work
|
||||
total_loss = sum(condition_losses)
|
||||
return total_loss
|
||||
return total_loss
|
||||
|
||||
Reference in New Issue
Block a user