From bb1efe44bca7baaa4b7c025a8637f1347f3c08c5 Mon Sep 17 00:00:00 2001 From: Dario Coscia <93731561+dario-coscia@users.noreply.github.com> Date: Tue, 8 Nov 2022 12:11:58 +0100 Subject: [PATCH] full compatibility with torch models * Network class added * adding tests for Network class --- pina/model/__init__.py | 2 + pina/model/network.py | 104 +++++++++++++++++++++++++++++++++++++++++ tests/test_network.py | 83 ++++++++++++++++++++++++++++++++ 3 files changed, 189 insertions(+) create mode 100644 pina/model/network.py create mode 100644 tests/test_network.py diff --git a/pina/model/__init__.py b/pina/model/__init__.py index a2d715a..ccdc8a8 100644 --- a/pina/model/__init__.py +++ b/pina/model/__init__.py @@ -2,8 +2,10 @@ __all__ = [ 'FeedForward', 'MultiFeedForward' 'DeepONet', + 'Network' ] from .feed_forward import FeedForward from .multi_feed_forward import MultiFeedForward from .deeponet import DeepONet +from .network import Network diff --git a/pina/model/network.py b/pina/model/network.py new file mode 100644 index 0000000..b661a05 --- /dev/null +++ b/pina/model/network.py @@ -0,0 +1,104 @@ +import torch +from pina.label_tensor import LabelTensor + + +class Network(torch.nn.Module): + """The PINA implementation of any neural network. + + :param torch.nn.Module model: the torch model of the network. + :param list(str) input_variables: the list containing the labels + corresponding to the input components of the model. + :param list(str) output_variables: the list containing the labels + corresponding to the components of the output computed by the model. + :param torch.nn.Module extra_features: the additional input + features to use as augmented input. + + :Example: + >>> class SimpleNet(nn.Module): + >>> def __init__(self): + >>> super().__init__() + >>> self.layers = nn.Sequential( + >>> nn.Linear(3, 20), + >>> nn.Tanh(), + >>> nn.Linear(20, 1) + >>> ) + >>> def forward(self, x): + >>> return self.layers(x) + >>> net = SimpleNet() + >>> input_variables = ['x', 'y'] + >>> output_variables =['u'] + >>> model_feat = Network(net, input_variables, output_variables) + Network( + (extra_features): Sequential() + (model): Sequential( + (0): Linear(in_features=2, out_features=20, bias=True) + (1): Tanh() + (2): Linear(in_features=20, out_features=1, bias=True) + ) + ) + """ + + def __init__(self, model, input_variables, + output_variables, extra_features=None): + super().__init__() + + if extra_features is None: + extra_features = [] + + self._extra_features = torch.nn.Sequential(*extra_features) + self._model = model + self._input_variables = input_variables + self._output_variables = output_variables + + # check model and input/output + self._check_consistency() + + def _check_consistency(self): + """Checking the consistency of model with input and output variables + + :raises ValueError: Error in constructing the PINA network + """ + try: + tmp = torch.rand((10, len(self._input_variables))) + tmp = LabelTensor(tmp, self._input_variables) + tmp = self.forward(tmp) # trying a forward pass + tmp = LabelTensor(tmp, self._output_variables) + except: + raise ValueError('Error in constructing the PINA network.' + ' Check compatibility of input/output' + ' variables shape with the torch model' + ' or check the correctness of the torch' + ' model itself.') + + def forward(self, x): + """Forward method for Network class + + :param torch.tensor x: input of the network + :return torch.tensor: output of the network + """ + + x = x.extract(self._input_variables) + + for feature in self._extra_features: + x = x.append(feature(x)) + + output = self._model(x).as_subclass(LabelTensor) + output.labels = self._output_variables + + return output + + @property + def input_variables(self): + return self._input_variables + + @property + def output_variables(self): + return self._output_variables + + @property + def extra_features(self): + return self._extra_features + + @property + def model(self): + return self._model diff --git a/tests/test_network.py b/tests/test_network.py new file mode 100644 index 0000000..691197a --- /dev/null +++ b/tests/test_network.py @@ -0,0 +1,83 @@ +import torch +import torch.nn as nn +import pytest +from pina.model import Network +from pina import LabelTensor + + +class SimpleNet(nn.Module): + + def __init__(self): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(2, 20), + nn.Tanh(), + nn.Linear(20, 1) + ) + + def forward(self, x): + return self.layers(x) + + +class SimpleNetExtraFeat(nn.Module): + + def __init__(self): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(3, 20), + nn.Tanh(), + nn.Linear(20, 1) + ) + + def forward(self, x): + return self.layers(x) + + +class myFeature(torch.nn.Module): + """ + Feature: sin(x) + """ + + def __init__(self): + super(myFeature, self).__init__() + + def forward(self, x): + t = (torch.sin(x.extract(['x'])*torch.pi) * + torch.sin(x.extract(['y'])*torch.pi)) + return LabelTensor(t, ['sin(x)sin(y)']) + + +input_variables = ['x', 'y'] +output_variables = ['u'] +data = torch.rand((20, 2)) +input_ = LabelTensor(data, input_variables) + + +def test_constructor(): + net = SimpleNet() + pina_net = Network(model=net, input_variables=input_variables, + output_variables=output_variables) + + +def test_forward(): + net = SimpleNet() + pina_net = Network(model=net, input_variables=input_variables, + output_variables=output_variables) + output_ = pina_net(input_) + assert output_.labels == output_variables + + +def test_constructor_extrafeat(): + net = SimpleNetExtraFeat() + feat = [myFeature()] + pina_net = Network(model=net, input_variables=input_variables, + output_variables=output_variables, extra_features=feat) + + +def test_forward_extrafeat(): + net = SimpleNetExtraFeat() + feat = [myFeature()] + pina_net = Network(model=net, input_variables=input_variables, + output_variables=output_variables, extra_features=feat) + output_ = pina_net(input_) + assert output_.labels == output_variables