fnn update, pinn torch models, tests update. (#88)
* fnn update, remove labeltensors * allow custom torch models * updating tests --------- Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local> Co-authored-by: Dario Coscia <dariocoscia@dhcp-031.eduroam.sissa.it>
This commit is contained in:
committed by
Nicola Demo
parent
c8fb7715c4
commit
be11110bb2
11
pina/pinn.py
11
pina/pinn.py
@@ -3,6 +3,7 @@ import torch
|
||||
import torch.optim.lr_scheduler as lrs
|
||||
|
||||
from .problem import AbstractProblem
|
||||
from .model import Network
|
||||
from .label_tensor import LabelTensor
|
||||
from .utils import merge_tensors, PinaDataset
|
||||
|
||||
@@ -15,6 +16,7 @@ class PINN(object):
|
||||
def __init__(self,
|
||||
problem,
|
||||
model,
|
||||
extra_features=None,
|
||||
optimizer=torch.optim.Adam,
|
||||
optimizer_kwargs=None,
|
||||
lr=0.001,
|
||||
@@ -28,6 +30,8 @@ class PINN(object):
|
||||
'''
|
||||
:param AbstractProblem problem: the formualation of the problem.
|
||||
:param torch.nn.Module model: the neural network model to use.
|
||||
:param torch.nn.Module extra_features: the additional input
|
||||
features to use as augmented input.
|
||||
:param torch.optim.Optimizer optimizer: the neural network optimizer to
|
||||
use; default is `torch.optim.Adam`.
|
||||
:param dict optimizer_kwargs: Optimizer constructor keyword args.
|
||||
@@ -68,7 +72,12 @@ class PINN(object):
|
||||
self.dtype = dtype
|
||||
self.history_loss = {}
|
||||
|
||||
self.model = model
|
||||
|
||||
self.model = Network(model=model,
|
||||
input_variables=problem.input_variables,
|
||||
output_variables=problem.output_variables,
|
||||
extra_features=extra_features)
|
||||
|
||||
self.model.to(dtype=self.dtype, device=self.device)
|
||||
|
||||
self.truth_values = {}
|
||||
|
||||
Reference in New Issue
Block a user