Lightining update (#104)

* multiple functions for version 0.0
* lightining update
* minor changes
* data pinn  loss added
---------

Co-authored-by: Nicola Demo <demo.nicola@gmail.com>
Co-authored-by: Dario Coscia <dariocoscia@cli-10-110-3-125.WIFIeduroamSTUD.units.it>
Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.station>
Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
Co-authored-by: Dario Coscia <dariocoscia@192.168.1.38>
This commit is contained in:
Dario Coscia
2023-06-07 15:34:43 +02:00
committed by Nicola Demo
parent 0e3625de80
commit 63fd068988
16 changed files with 710 additions and 603 deletions

View File

@@ -1,107 +1,47 @@
import torch
from pina.label_tensor import LabelTensor
import torch.nn as nn
from ..utils import check_consistency
class Network(torch.nn.Module):
"""The PINA implementation of any neural network.
:param torch.nn.Module model: the torch model of the network.
:param list(str) input_variables: the list containing the labels
corresponding to the input components of the model.
:param list(str) output_variables: the list containing the labels
corresponding to the components of the output computed by the model.
:param torch.nn.Module extra_features: the additional input
features to use as augmented input.
:Example:
>>> class SimpleNet(nn.Module):
>>> def __init__(self):
>>> super().__init__()
>>> self.layers = nn.Sequential(
>>> nn.Linear(2, 20),
>>> nn.Tanh(),
>>> nn.Linear(20, 1)
>>> )
>>> def forward(self, x):
>>> return self.layers(x)
>>> net = SimpleNet()
>>> input_variables = ['x', 'y']
>>> output_variables =['u']
>>> model_feat = Network(net, input_variables, output_variables)
Network(
(extra_features): Sequential()
(model): Sequential(
(0): Linear(in_features=2, out_features=20, bias=True)
(1): Tanh()
(2): Linear(in_features=20, out_features=1, bias=True)
)
)
"""
def __init__(self, model, input_variables,
output_variables, extra_features=None):
def __init__(self, model, extra_features=None):
super().__init__()
print('HHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHH')
if extra_features is None:
extra_features = []
self._extra_features = torch.nn.Sequential(*extra_features)
# check model consistency
check_consistency(model, nn.Module, 'torch model')
self._model = model
self._input_variables = input_variables
self._output_variables = output_variables
print(output_variables)
# check model and input/output
self._check_consistency()
# check consistency and assign extra fatures
if extra_features is None:
self._extra_features = []
else:
for feat in extra_features:
check_consistency(feat, nn.Module, 'extra features')
self._extra_features = nn.Sequential(*extra_features)
def _check_consistency(self):
"""Checking the consistency of model with input and output variables
:raises ValueError: Error in constructing the PINA network
"""
try:
pass
# tmp = torch.rand((10, len(self._input_variables)))
# tmp = LabelTensor(tmp, self._input_variables)
# tmp = self.forward(tmp) # trying a forward pass
# tmp = LabelTensor(tmp, self._output_variables)
except:
raise ValueError('Error in constructing the PINA network.'
' Check compatibility of input/output'
' variables shape with the torch model'
' or check the correctness of the torch'
' model itself.')
# check model works with inputs
# TODO
def forward(self, x):
"""Forward method for Network class
"""
Forward method for Network class. This class
implements the standard forward method, and
it adds the possibility to pass extra features.
:param torch.tensor x: input of the network
:return torch.tensor: output of the network
"""
x = x.extract(self._input_variables)
# extract features and append
for feature in self._extra_features:
x = x.append(feature(x))
output = self._model(x).as_subclass(LabelTensor)
output.labels = self._output_variables
return output
@property
def input_variables(self):
return self._input_variables
@property
def output_variables(self):
return self._output_variables
@property
def extra_features(self):
return self._extra_features
# perform forward pass
return self._model(x)
@property
def model(self):
return self._model
@property
def extra_features(self):
return self._extra_features