fnn update, pinn torch models, tests update. (#88)

* fnn update, remove labeltensors
* allow custom torch models
* updating tests

---------

Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
Co-authored-by: Dario Coscia <dariocoscia@dhcp-031.eduroam.sissa.it>
This commit is contained in:
Dario Coscia
2023-05-02 15:19:48 +02:00
committed by Nicola Demo
parent c8fb7715c4
commit be11110bb2
11 changed files with 149 additions and 177 deletions

View File

@@ -3,6 +3,7 @@ import torch
import torch.optim.lr_scheduler as lrs
from .problem import AbstractProblem
from .model import Network
from .label_tensor import LabelTensor
from .utils import merge_tensors, PinaDataset
@@ -15,6 +16,7 @@ class PINN(object):
def __init__(self,
problem,
model,
extra_features=None,
optimizer=torch.optim.Adam,
optimizer_kwargs=None,
lr=0.001,
@@ -28,6 +30,8 @@ class PINN(object):
'''
:param AbstractProblem problem: the formualation of the problem.
:param torch.nn.Module model: the neural network model to use.
:param torch.nn.Module extra_features: the additional input
features to use as augmented input.
:param torch.optim.Optimizer optimizer: the neural network optimizer to
use; default is `torch.optim.Adam`.
:param dict optimizer_kwargs: Optimizer constructor keyword args.
@@ -68,7 +72,12 @@ class PINN(object):
self.dtype = dtype
self.history_loss = {}
self.model = model
self.model = Network(model=model,
input_variables=problem.input_variables,
output_variables=problem.output_variables,
extra_features=extra_features)
self.model.to(dtype=self.dtype, device=self.device)
self.truth_values = {}