Layers and Models update PR
* add residual block * add test conv and residual block * modify FFN kwargs
This commit is contained in:
committed by
Nicola Demo
parent
8c16e27ae4
commit
15ecaacb7c
@@ -2,19 +2,17 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from pina.label_tensor import LabelTensor
|
|
||||||
|
|
||||||
|
|
||||||
class FeedForward(torch.nn.Module):
|
class FeedForward(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
The PINA implementation of feedforward network, also refered as multilayer
|
The PINA implementation of feedforward network, also refered as multilayer
|
||||||
perceptron.
|
perceptron.
|
||||||
|
|
||||||
:param int input_variables: The number of input components of the model.
|
:param int input_dimensons: The number of input components of the model.
|
||||||
Expected tensor shape of the form (*, input_variables), where *
|
Expected tensor shape of the form (*, input_dimensons), where *
|
||||||
means any number of dimensions including none.
|
means any number of dimensions including none.
|
||||||
:param int output_variables: The number of output components of the model.
|
:param int output_dimensions: The number of output components of the model.
|
||||||
Expected tensor shape of the form (*, output_variables), where *
|
Expected tensor shape of the form (*, output_dimensions), where *
|
||||||
means any number of dimensions including none.
|
means any number of dimensions including none.
|
||||||
:param int inner_size: number of neurons in the hidden layer(s). Default is
|
:param int inner_size: number of neurons in the hidden layer(s). Default is
|
||||||
20.
|
20.
|
||||||
@@ -28,20 +26,20 @@ class FeedForward(torch.nn.Module):
|
|||||||
`inner_size` are not considered.
|
`inner_size` are not considered.
|
||||||
:param bool bias: If `True` the MLP will consider some bias.
|
:param bool bias: If `True` the MLP will consider some bias.
|
||||||
"""
|
"""
|
||||||
def __init__(self, input_variables, output_variables, inner_size=20,
|
def __init__(self, input_dimensons, output_dimensions, inner_size=20,
|
||||||
n_layers=2, func=nn.Tanh, layers=None, bias=True):
|
n_layers=2, func=nn.Tanh, layers=None, bias=True):
|
||||||
"""
|
"""
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
|
||||||
if not isinstance(input_variables, int):
|
if not isinstance(input_dimensons, int):
|
||||||
raise ValueError('input_variables expected to be int.')
|
raise ValueError('input_dimensons expected to be int.')
|
||||||
self.input_dimension = input_variables
|
self.input_dimension = input_dimensons
|
||||||
|
|
||||||
if not isinstance(output_variables, int):
|
if not isinstance(output_dimensions, int):
|
||||||
raise ValueError('output_variables expected to be int.')
|
raise ValueError('output_dimensions expected to be int.')
|
||||||
self.output_dimension = output_variables
|
self.output_dimension = output_dimensions
|
||||||
if layers is None:
|
if layers is None:
|
||||||
layers = [inner_size] * n_layers
|
layers = [inner_size] * n_layers
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
'BaseContinuousConv',
|
'ContinuousConvBlock',
|
||||||
'ContinuousConv'
|
'ResidualBlock'
|
||||||
]
|
]
|
||||||
|
|
||||||
from .convolution import BaseContinuousConv
|
from .convolution_2d import ContinuousConvBlock
|
||||||
from .convolution_2d import ContinuousConv
|
from .residual import ResidualBlock
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from ..feed_forward import FeedForward
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class ContinuousConv(BaseContinuousConv):
|
class ContinuousConvBlock(BaseContinuousConv):
|
||||||
"""
|
"""
|
||||||
Implementation of Continuous Convolutional operator.
|
Implementation of Continuous Convolutional operator.
|
||||||
|
|
||||||
|
|||||||
24
pina/model/layers/fourier.py
Normal file
24
pina/model/layers/fourier.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from ...utils import check_consistency
|
||||||
|
|
||||||
|
|
||||||
|
class FourierBlock(nn.Module):
|
||||||
|
"""Fourier block base class. Implementation of a fourier block.
|
||||||
|
|
||||||
|
.. seealso::
|
||||||
|
|
||||||
|
**Original reference**: Li, Zongyi, et al.
|
||||||
|
"Fourier neural operator for parametric partial
|
||||||
|
differential equations." arXiv preprint
|
||||||
|
arXiv:2010.08895 (2020)
|
||||||
|
<https://arxiv.org/abs/2010.08895.pdf>`_.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
pass
|
||||||
@@ -1,4 +1,3 @@
|
|||||||
""" Integral class for continous convolution"""
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
95
pina/model/layers/residual.py
Normal file
95
pina/model/layers/residual.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from ...utils import check_consistency
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock(nn.Module):
|
||||||
|
"""Residual block base class. Implementation of a residual block.
|
||||||
|
|
||||||
|
.. seealso::
|
||||||
|
|
||||||
|
**Original reference**: He, Kaiming, et al.
|
||||||
|
"Deep residual learning for image recognition."
|
||||||
|
Proceedings of the IEEE conference on computer vision
|
||||||
|
and pattern recognition. 2016..
|
||||||
|
<https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, input_dim, output_dim,
|
||||||
|
hidden_dim, spectral_norm=False,
|
||||||
|
activation=torch.nn.ReLU()):
|
||||||
|
"""Residual block constructor
|
||||||
|
|
||||||
|
:param int input_dim: Dimension of the input to pass to the
|
||||||
|
feedforward linear layer.
|
||||||
|
:param int output_dim: Dimension of the output from the
|
||||||
|
residual layer.
|
||||||
|
:param int hidden_dim: Hidden dimension for mapping the input
|
||||||
|
(first block).
|
||||||
|
:param bool spectral_norm: Apply spectral normalization to feedforward
|
||||||
|
layers, defaults to False.
|
||||||
|
:param torch.nn.Module activation: Cctivation function after first block.
|
||||||
|
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
# check consistency
|
||||||
|
check_consistency(spectral_norm, bool)
|
||||||
|
check_consistency(input_dim, int)
|
||||||
|
check_consistency(output_dim, int)
|
||||||
|
check_consistency(hidden_dim, int)
|
||||||
|
check_consistency(activation, torch.nn.Module)
|
||||||
|
|
||||||
|
# assign variables
|
||||||
|
self._spectral_norm = spectral_norm
|
||||||
|
self._input_dim = input_dim
|
||||||
|
self._output_dim = output_dim
|
||||||
|
self._hidden_dim = hidden_dim
|
||||||
|
self._activation = activation
|
||||||
|
|
||||||
|
# create layers
|
||||||
|
self.l1 = self._spect_norm(nn.Linear(input_dim, hidden_dim))
|
||||||
|
self.l2 = self._spect_norm(nn.Linear(hidden_dim, output_dim))
|
||||||
|
self.l3 = self._spect_norm(nn.Linear(input_dim, output_dim))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""Forward pass for residual block layer.
|
||||||
|
|
||||||
|
:param torch.Tensor x: Input tensor for the residual layer.
|
||||||
|
:return: Output tensor for the residual layer.
|
||||||
|
:rtype: torch.Tensor
|
||||||
|
"""
|
||||||
|
y = self.activation(self.l1(x))
|
||||||
|
y = self.l2(y)
|
||||||
|
x = self.l3(x)
|
||||||
|
return y + x
|
||||||
|
|
||||||
|
def _spect_norm(self, x):
|
||||||
|
"""Perform spectral norm on the layers.
|
||||||
|
|
||||||
|
:param x: A torch.nn.Module Linear layer
|
||||||
|
:type x: torch.nn.Module
|
||||||
|
:return: The spectral norm of the layer
|
||||||
|
:rtype: torch.nn.Module
|
||||||
|
"""
|
||||||
|
return nn.utils.spectral_norm(x) if self._spectral_norm else x
|
||||||
|
|
||||||
|
@ property
|
||||||
|
def spectral_norm(self):
|
||||||
|
return self._spectral_norm
|
||||||
|
|
||||||
|
@ property
|
||||||
|
def input_dim(self):
|
||||||
|
return self._input_dim
|
||||||
|
|
||||||
|
@ property
|
||||||
|
def output_dim(self):
|
||||||
|
return self._output_dim
|
||||||
|
|
||||||
|
@ property
|
||||||
|
def hidden_dim(self):
|
||||||
|
return self._hidden_dim
|
||||||
|
|
||||||
|
@ property
|
||||||
|
def activation(self):
|
||||||
|
return self._activation
|
||||||
16
pina/model/layers/spectral.py
Normal file
16
pina/model/layers/spectral.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from ...utils import check_consistency
|
||||||
|
|
||||||
|
|
||||||
|
class SpectralConvBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
Implementation of spectral convolution block.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
pass
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
from pina.model.layers import ContinuousConv
|
from pina.model.layers import ContinuousConvBlock
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
@@ -70,12 +70,12 @@ x = make_grid(x)
|
|||||||
def test_constructor():
|
def test_constructor():
|
||||||
model = MLP
|
model = MLP
|
||||||
|
|
||||||
conv = ContinuousConv(channel_input,
|
conv = ContinuousConvBlock(channel_input,
|
||||||
channel_output,
|
channel_output,
|
||||||
dim,
|
dim,
|
||||||
stride,
|
stride,
|
||||||
model=model)
|
model=model)
|
||||||
conv = ContinuousConv(channel_input,
|
conv = ContinuousConvBlock(channel_input,
|
||||||
channel_output,
|
channel_output,
|
||||||
dim,
|
dim,
|
||||||
stride,
|
stride,
|
||||||
@@ -86,7 +86,7 @@ def test_forward():
|
|||||||
model = MLP
|
model = MLP
|
||||||
|
|
||||||
# simple forward
|
# simple forward
|
||||||
conv = ContinuousConv(channel_input,
|
conv = ContinuousConvBlock(channel_input,
|
||||||
channel_output,
|
channel_output,
|
||||||
dim,
|
dim,
|
||||||
stride,
|
stride,
|
||||||
@@ -94,7 +94,7 @@ def test_forward():
|
|||||||
conv(x)
|
conv(x)
|
||||||
|
|
||||||
# simple forward with optimization
|
# simple forward with optimization
|
||||||
conv = ContinuousConv(channel_input,
|
conv = ContinuousConvBlock(channel_input,
|
||||||
channel_output,
|
channel_output,
|
||||||
dim,
|
dim,
|
||||||
stride,
|
stride,
|
||||||
@@ -107,13 +107,13 @@ def test_transpose():
|
|||||||
model = MLP
|
model = MLP
|
||||||
|
|
||||||
# simple transpose
|
# simple transpose
|
||||||
conv = ContinuousConv(channel_input,
|
conv = ContinuousConvBlock(channel_input,
|
||||||
channel_output,
|
channel_output,
|
||||||
dim,
|
dim,
|
||||||
stride,
|
stride,
|
||||||
model=model)
|
model=model)
|
||||||
|
|
||||||
conv2 = ContinuousConv(channel_output,
|
conv2 = ContinuousConvBlock(channel_output,
|
||||||
channel_input,
|
channel_input,
|
||||||
dim,
|
dim,
|
||||||
stride,
|
stride,
|
||||||
@@ -122,13 +122,13 @@ def test_transpose():
|
|||||||
integrals = conv(x)
|
integrals = conv(x)
|
||||||
conv2.transpose(integrals[..., -1], x)
|
conv2.transpose(integrals[..., -1], x)
|
||||||
|
|
||||||
stride_no_overlap = {"domain": [10, 10],
|
# stride_no_overlap = {"domain": [10, 10],
|
||||||
"start": [0, 0],
|
# "start": [0, 0],
|
||||||
"jumps": dim,
|
# "jumps": dim,
|
||||||
"direction": [1, 1.]}
|
# "direction": [1, 1.]}
|
||||||
|
|
||||||
# simple transpose with optimization
|
## simple transpose with optimization
|
||||||
# conv = ContinuousConv(channel_input,
|
# conv = ContinuousConvBlock(channel_input,
|
||||||
# channel_output,
|
# channel_output,
|
||||||
# dim,
|
# dim,
|
||||||
# stride_no_overlap,
|
# stride_no_overlap,
|
||||||
@@ -137,4 +137,4 @@ def test_transpose():
|
|||||||
# no_overlap=True)
|
# no_overlap=True)
|
||||||
|
|
||||||
# integrals = conv(x)
|
# integrals = conv(x)
|
||||||
# conv.transpose(integrals[..., -1], x)
|
# conv.transpose(integrals[..., -1], x)
|
||||||
26
tests/test_layers/test_residual.py
Normal file
26
tests/test_layers/test_residual.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
from pina.model.layers import ResidualBlock
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def test_constructor():
|
||||||
|
|
||||||
|
res_block = ResidualBlock(input_dim=10,
|
||||||
|
output_dim=3,
|
||||||
|
hidden_dim=4)
|
||||||
|
|
||||||
|
res_block = ResidualBlock(input_dim=10,
|
||||||
|
output_dim=3,
|
||||||
|
hidden_dim=4,
|
||||||
|
spectral_norm=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_forward():
|
||||||
|
|
||||||
|
res_block = ResidualBlock(input_dim=10,
|
||||||
|
output_dim=3,
|
||||||
|
hidden_dim=4)
|
||||||
|
|
||||||
|
x = torch.rand(size=(80, 10))
|
||||||
|
y = res_block(x)
|
||||||
|
assert y.shape[1]==3
|
||||||
|
assert y.shape[0]==x.shape[0]
|
||||||
Reference in New Issue
Block a user