New Residual Model and Fix relative import
* Adding Residual MLP
* Adding test Residual MLP
* Modified relative import Continuous Conv
committed by Nicola Demo
parent ba7371f350
commit 17464ceca9
@@ -1,12 +1,13 @@
 __all__ = [
     'FeedForward',
+    'ResidualFeedForward',
     'MultiFeedForward',
     'DeepONet',
     'MIONet',
     'FNO',
 ]

-from .feed_forward import FeedForward
+from .feed_forward import FeedForward, ResidualFeedForward
 from .multi_feed_forward import MultiFeedForward
 from .deeponet import DeepONet, MIONet
 from .fno import FNO

@@ -1,6 +1,8 @@
 """Module for FeedForward model"""
 import torch
 import torch.nn as nn
+from ..utils import check_consistency
+from .layers.residual import EnhancedLinear


 class FeedForward(torch.nn.Module):
@@ -8,8 +10,8 @@ class FeedForward(torch.nn.Module):
     The PINA implementation of feedforward network, also refered as multilayer
     perceptron.

-    :param int input_dimensons: The number of input components of the model.
-        Expected tensor shape of the form (*, input_dimensons), where *
+    :param int input_dimensions: The number of input components of the model.
+        Expected tensor shape of the form (*, input_dimensions), where *
         means any number of dimensions including none.
     :param int output_dimensions: The number of output components of the model.
         Expected tensor shape of the form (*, output_dimensions), where *
@@ -80,3 +82,130 @@ class FeedForward(torch.nn.Module):
         :rtype: LabelTensor
         """
         return self.model(x)
+
+
+class ResidualFeedForward(torch.nn.Module):
+    """
+    The PINA implementation of a feedforward network enhanced with skip
+    connections and transformer networks, as presented in **Understanding and
+    mitigating gradient flow pathologies in physics-informed neural networks**.
+
+    .. seealso::
+
+        **Original reference**: Wang, Sifan, Yujun Teng, and Paris Perdikaris.
+        "Understanding and mitigating gradient flow pathologies in
+        physics-informed neural networks." SIAM Journal on Scientific Computing
+        43.5 (2021): A3055-A3081. DOI: `10.1137/20M1318043
+        <https://epubs.siam.org/doi/abs/10.1137/20M1318043>`_
+
+    :param int input_dimensions: The number of input components of the model.
+        Expected tensor shape of the form (*, input_dimensions), where *
+        means any number of dimensions including none.
+    :param int output_dimensions: The number of output components of the model.
+        Expected tensor shape of the form (*, output_dimensions), where *
+        means any number of dimensions including none.
+    :param int inner_size: number of neurons in the hidden layer(s). Default is
+        20.
+    :param int n_layers: number of hidden layers. Default is 2.
+    :param func: the activation function to use. If a single
+        :class:`torch.nn.Module` is passed, it is used as the activation
+        function after every layer except the last one. If a list of Modules
+        is passed, they are used as the activation functions of the layers,
+        in order.
+    :param bool bias: If ``True``, bias terms are included in the linear layers.
+    :param list | tuple transformer_nets: a list or tuple containing the two
+        torch.nn.Module objects which act as transformer networks. The input
+        dimension of each network must be the same as ``input_dimensions``,
+        and the output dimension must be the same as ``inner_size``.
+    """
+
+    def __init__(self, input_dimensions, output_dimensions, inner_size=20,
+                 n_layers=2, func=nn.Tanh, bias=True, transformer_nets=None):
+        super().__init__()
+
+        # check type consistency
+        check_consistency(input_dimensions, int)
+        check_consistency(output_dimensions, int)
+        check_consistency(inner_size, int)
+        check_consistency(n_layers, int)
+        check_consistency(func, torch.nn.Module, subclass=True)
+        check_consistency(bias, bool)
+
+        # check transformer nets
+        if transformer_nets is None:
+            transformer_nets = [
+                EnhancedLinear(nn.Linear(in_features=input_dimensions,
+                                         out_features=inner_size),
+                               nn.Tanh()),
+                EnhancedLinear(nn.Linear(in_features=input_dimensions,
+                                         out_features=inner_size),
+                               nn.Tanh())
+            ]
+        elif isinstance(transformer_nets, (list, tuple)):
+            if len(transformer_nets) != 2:
+                raise ValueError('transformer_nets needs to be a list of length two.')
+            for net in transformer_nets:
+                if not isinstance(net, nn.Module):
+                    raise ValueError('transformer_nets needs to be a list of torch.nn.Module.')
+                x = torch.rand(10, input_dimensions)
+                try:
+                    out = net(x)
+                except RuntimeError:
+                    raise ValueError('transformer network input incompatible with input_dimensions.')
+                if out.shape[-1] != inner_size:
+                    raise ValueError('transformer network output incompatible with inner_size.')
+        else:
+            raise RuntimeError('Invalid transformer_nets, check the official documentation.')
+
+        # assign variables
+        self.input_dimension = input_dimensions
+        self.output_dimension = output_dimensions
+        self.transformer_nets = nn.ModuleList(transformer_nets)
+
+        # build layers
+        layers = [inner_size] * n_layers
+
+        tmp_layers = layers.copy()
+        tmp_layers.insert(0, self.input_dimension)
+
+        self.layers = []
+        for i in range(len(tmp_layers) - 1):
+            self.layers.append(
+                nn.Linear(tmp_layers[i], tmp_layers[i + 1], bias=bias)
+            )
+        self.last_layer = nn.Linear(tmp_layers[-1], output_dimensions, bias=bias)
+
+        if isinstance(func, list):
+            self.functions = [f() for f in func]
+        else:
+            self.functions = [func() for _ in range(len(self.layers))]
+
+        if len(self.layers) != len(self.functions):
+            raise RuntimeError('inconsistent number of layers and functions')
+
+        unique_list = []
+        for layer, func in zip(self.layers, self.functions):
+            unique_list.append(EnhancedLinear(layer=layer, activation=func))
+        self.inner_layers = torch.nn.Sequential(*unique_list)
+
+    def forward(self, x):
+        """
+        Defines the computation performed at every call.
+
+        :param x: the input tensor.
+        :type x: :class:`pina.LabelTensor`
+        :return: the output computed by the model.
+        :rtype: LabelTensor
+        """
+        # enhance the input with the transformer networks
+        input_ = []
+        for nets in self.transformer_nets:
+            input_.append(nets(x))
+
+        # forward pass with skip connections
+        for layer in self.inner_layers.children():
+            x = layer(x)
+            x = (1. - x) * input_[0] + x * input_[1]
+
+        # last layer
+        return self.last_layer(x)
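
Reviewer note: a minimal usage sketch of the new model, written against the exports added above. Only the batch size and the printed shape are illustrative; all names and parameters come from the commit.

    import torch
    from pina.model import ResidualFeedForward

    # two-dimensional inputs, scalar output; defaults inner_size=20, n_layers=2
    model = ResidualFeedForward(input_dimensions=2, output_dimensions=1)

    x = torch.rand(16, 2)   # batch of 16 samples
    y = model(x)            # shape: (16, 1)

    # Each hidden output z is blended with the two transformer networks
    # U = T1(x), V = T2(x) via the gating update  z <- (1 - z) * U + z * V,
    # following Wang, Teng, Perdikaris (2021).
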
@@ -1,6 +1,7 @@
 __all__ = [
     'ContinuousConvBlock',
     'ResidualBlock',
+    'EnhancedLinear',
     'SpectralConvBlock1D',
     'SpectralConvBlock2D',
     'SpectralConvBlock3D',
@@ -10,6 +11,6 @@ __all__ = [
 ]

 from .convolution_2d import ContinuousConvBlock
-from .residual import ResidualBlock
+from .residual import ResidualBlock, EnhancedLinear
 from .spectral import SpectralConvBlock1D, SpectralConvBlock2D, SpectralConvBlock3D
 from .fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D

@@ -113,6 +113,21 @@ class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
         else:
             self.transpose = self.transpose_overlap

+    class DefaultKernel(torch.nn.Module):
+        def __init__(self, input_dim, output_dim):
+            super().__init__()
+            assert isinstance(input_dim, int)
+            assert isinstance(output_dim, int)
+            self._model = torch.nn.Sequential(
+                torch.nn.Linear(input_dim, 20),
+                torch.nn.ReLU(),
+                torch.nn.Linear(20, 20),
+                torch.nn.ReLU(),
+                torch.nn.Linear(20, output_dim)
+            )
+        def forward(self, x):
+            return self._model(x)
+
     @ property
     def net(self):
         return self._net

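For context, a small sketch of the new default kernel in isolation. The module path ``pina.model.layers.convolution`` is an assumption inferred from the relative imports in this diff; the sizes are arbitrary.

    import torch
    from pina.model.layers.convolution import BaseContinuousConv  # assumed module path

    # a plain MLP with two hidden layers of 20 neurons, used when model=None
    kernel = BaseContinuousConv.DefaultKernel(input_dim=2, output_dim=1)
    out = kernel(torch.rand(5, 2))   # shape: (5, 1)
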
@@ -2,7 +2,6 @@
 from .convolution import BaseContinuousConv
 from .utils_convolution import check_point, map_points_
 from .integral import Integral
-from ..feed_forward import FeedForward
 import torch


@@ -34,8 +33,8 @@ class ContinuousConvBlock(BaseContinuousConv):
     :param stride: Stride for the filter.
     :type stride: dict
     :param model: Neural network for inner parametrization,
-        defaults to None. If None, pina.FeedForward is used, more
-        on https://mathlab.github.io/PINA/_rst/fnn.html.
+        defaults to None. If None, a default multilayer perceptron
+        is used, see BaseContinuousConv.DefaultKernel.
     :type model: torch.nn.Module, optional
     :param optimize: Flag for performing optimization on the continuous
         filter, defaults to False. The flag `optimize=True` should be
@@ -152,7 +151,7 @@ class ContinuousConvBlock(BaseContinuousConv):
         nets = []
         if self._net is None:
             for _ in range(self._input_numb_field * self._output_numb_field):
-                tmp = FeedForward(len(self._dim), 1)
+                tmp = ContinuousConvBlock.DefaultKernel(len(self._dim), 1)
                 nets.append(tmp)
         else:
             if not isinstance(model, object):

@@ -93,3 +93,38 @@ class ResidualBlock(nn.Module):
     @ property
     def activation(self):
         return self._activation
+
+
+class EnhancedLinear(torch.nn.Module):
+    """
+    Wrapper combining a layer with an optional activation and an optional
+    dropout into a single :class:`torch.nn.Sequential` block.
+    """
+    def __init__(self, layer, activation=None, dropout=None):
+        super().__init__()
+
+        # check consistency
+        check_consistency(layer, nn.Module)
+        if activation is not None:
+            check_consistency(activation, nn.Module)
+        if dropout is not None:
+            check_consistency(dropout, float)
+
+        # assign forward
+        if (dropout is None) and (activation is None):
+            self._model = torch.nn.Sequential(layer)
+
+        elif (dropout is None) and (activation is not None):
+            self._model = torch.nn.Sequential(layer, activation)
+
+        elif (dropout is not None) and (activation is None):
+            self._model = torch.nn.Sequential(layer, self._drop(dropout))
+
+        elif (dropout is not None) and (activation is not None):
+            self._model = torch.nn.Sequential(layer, activation, self._drop(dropout))
+
+    def forward(self, x):
+        return self._model(x)
+
+    @staticmethod
+    def _drop(p):
+        # small helper building the dropout module used above
+        return torch.nn.Dropout(p)

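A short sketch of the new EnhancedLinear wrapper. The import path ``pina.model.layers`` follows from the package ``__init__`` updated above; the layer sizes and dropout rate are illustrative, and the dropout case relies on the small ``_drop`` helper noted in the diff.

    import torch
    from pina.model.layers import EnhancedLinear

    block = EnhancedLinear(layer=torch.nn.Linear(2, 20),
                           activation=torch.nn.Tanh(),
                           dropout=0.1)
    out = block(torch.rand(8, 2))   # shape: (8, 20), i.e. Linear -> Tanh -> Dropout
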
tests/test_model/test_residualfnn.py (new file, +22)
@@ -0,0 +1,22 @@
+import torch
+import pytest
+from pina.model import ResidualFeedForward
+
+
+def test_constructor():
+    # simple constructor
+    ResidualFeedForward(input_dimensions=2, output_dimensions=1)
+
+    # wrong transformer nets (not 2)
+    with pytest.raises(ValueError):
+        ResidualFeedForward(input_dimensions=2, output_dimensions=1, transformer_nets=[torch.nn.Linear(2, 20)])
+
+    # wrong transformer nets (not nn.Module)
+    with pytest.raises(ValueError):
+        ResidualFeedForward(input_dimensions=2, output_dimensions=1, transformer_nets=[2, 2])
+
+
+def test_forward():
+    x = torch.rand(10, 2)
+    model = ResidualFeedForward(input_dimensions=2, output_dimensions=1)
+    model(x)
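
Beyond the added tests, a sketch (not part of the commit) of passing custom transformer networks; their output size must match ``inner_size``, as checked in the constructor above.

    import torch
    from pina.model import ResidualFeedForward

    nets = [torch.nn.Linear(2, 20), torch.nn.Linear(2, 20)]
    model = ResidualFeedForward(input_dimensions=2, output_dimensions=1,
                                inner_size=20, transformer_nets=nets)
    model(torch.rand(10, 2))   # forward pass with the custom transformer nets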