Fourier Block and minor fixes
* Adding fourier block 1d/2d/3d * Adding docs to SpectralConvBlock1D/2D/3D and to FourierBlock1D/2D/3D * Adding tests for fourier block
This commit is contained in:
committed by
Nicola Demo
parent
2bf42d5fea
commit
83ecdb0eab
@@ -3,9 +3,13 @@ __all__ = [
|
|||||||
'ResidualBlock',
|
'ResidualBlock',
|
||||||
'SpectralConvBlock1D',
|
'SpectralConvBlock1D',
|
||||||
'SpectralConvBlock2D',
|
'SpectralConvBlock2D',
|
||||||
'SpectralConvBlock3D'
|
'SpectralConvBlock3D',
|
||||||
|
'FourierBlock1D',
|
||||||
|
'FourierBlock2D',
|
||||||
|
'FourierBlock3D',
|
||||||
]
|
]
|
||||||
|
|
||||||
from .convolution_2d import ContinuousConvBlock
|
from .convolution_2d import ContinuousConvBlock
|
||||||
from .residual import ResidualBlock
|
from .residual import ResidualBlock
|
||||||
from .spectral import SpectralConvBlock1D, SpectralConvBlock2D, SpectralConvBlock3D
|
from .spectral import SpectralConvBlock1D, SpectralConvBlock2D, SpectralConvBlock3D
|
||||||
|
from .fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D
|
||||||
|
|||||||
@@ -2,9 +2,14 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from ...utils import check_consistency
|
from ...utils import check_consistency
|
||||||
|
|
||||||
|
from pina.model.layers import SpectralConvBlock1D, SpectralConvBlock2D, SpectralConvBlock3D
|
||||||
|
|
||||||
class FourierBlock(nn.Module):
|
|
||||||
"""Fourier block base class. Implementation of a fourier block.
|
class FourierBlock1D(nn.Module):
|
||||||
|
"""
|
||||||
|
Fourier block implementation for three dimensional
|
||||||
|
input tensor. The combination of Fourier blocks
|
||||||
|
make up the Fourier Neural Operator
|
||||||
|
|
||||||
.. seealso::
|
.. seealso::
|
||||||
|
|
||||||
@@ -16,9 +21,143 @@ class FourierBlock(nn.Module):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, input_numb_fields, output_numb_fields, n_modes, activation=torch.nn.Tanh):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
"""
|
||||||
|
PINA implementation of Fourier block one dimension. The module computes
|
||||||
|
the spectral convolution of the input with a linear kernel in the
|
||||||
|
fourier space, and then it maps the input back to the physical
|
||||||
|
space. The output is then added to a Linear tranformation of the
|
||||||
|
input in the physical space. Finally an activation function is
|
||||||
|
applied to the output.
|
||||||
|
|
||||||
|
The block expects an input of size ``[batch, input_numb_fields, N]``
|
||||||
|
and returns an output of size ``[batch, output_numb_fields, N]``.
|
||||||
|
|
||||||
|
:param int input_numb_fields: The number of channels for the input.
|
||||||
|
:param int output_numb_fields: The number of channels for the output.
|
||||||
|
:param list | tuple n_modes: Number of modes to select for each dimension.
|
||||||
|
It must be at most equal to the ``floor(N/2)+1``.
|
||||||
|
:param torch.nn.Module activation: The activation function.
|
||||||
|
"""
|
||||||
|
# check type consistency
|
||||||
|
check_consistency(activation(), nn.Module)
|
||||||
|
|
||||||
|
# assign variables
|
||||||
|
self._spectral_conv = SpectralConvBlock1D(input_numb_fields=input_numb_fields,
|
||||||
|
output_numb_fields=output_numb_fields,
|
||||||
|
n_modes=n_modes)
|
||||||
|
self._activation = activation()
|
||||||
|
self._linear = nn.Conv1d(input_numb_fields, output_numb_fields, 1)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
pass
|
return self._activation(self._spectral_conv(x) + self._linear(x))
|
||||||
|
|
||||||
|
|
||||||
|
class FourierBlock2D(nn.Module):
|
||||||
|
"""
|
||||||
|
Fourier block implementation for two dimensional
|
||||||
|
input tensor. The combination of Fourier blocks
|
||||||
|
make up the Fourier Neural Operator
|
||||||
|
|
||||||
|
.. seealso::
|
||||||
|
|
||||||
|
**Original reference**: Li, Zongyi, et al.
|
||||||
|
"Fourier neural operator for parametric partial
|
||||||
|
differential equations." arXiv preprint
|
||||||
|
arXiv:2010.08895 (2020)
|
||||||
|
<https://arxiv.org/abs/2010.08895.pdf>`_.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, input_numb_fields, output_numb_fields, n_modes, activation=torch.nn.Tanh):
|
||||||
|
"""
|
||||||
|
PINA implementation of Fourier block two dimensions. The module computes
|
||||||
|
the spectral convolution of the input with a linear kernel in the
|
||||||
|
fourier space, and then it maps the input back to the physical
|
||||||
|
space. The output is then added to a Linear tranformation of the
|
||||||
|
input in the physical space. Finally an activation function is
|
||||||
|
applied to the output.
|
||||||
|
|
||||||
|
The block expects an input of size ``[batch, input_numb_fields, Nx, Ny]``
|
||||||
|
and returns an output of size ``[batch, output_numb_fields, Nx, Ny]``.
|
||||||
|
|
||||||
|
:param int input_numb_fields: The number of channels for the input.
|
||||||
|
:param int output_numb_fields: The number of channels for the output.
|
||||||
|
:param list | tuple n_modes: Number of modes to select for each dimension.
|
||||||
|
It must be at most equal to the ``floor(Nx/2)+1`` and ``floor(Ny/2)+1``.
|
||||||
|
:param torch.nn.Module activation: The activation function.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# check type consistency
|
||||||
|
check_consistency(activation(), nn.Module)
|
||||||
|
|
||||||
|
# assign variables
|
||||||
|
self._spectral_conv = SpectralConvBlock2D(input_numb_fields=input_numb_fields,
|
||||||
|
output_numb_fields=output_numb_fields,
|
||||||
|
n_modes=n_modes)
|
||||||
|
self._activation = activation()
|
||||||
|
self._linear = nn.Conv1d(input_numb_fields, output_numb_fields, 1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
shape_x = x.shape
|
||||||
|
ln = self._linear(x.view(shape_x[0], shape_x[1], -1))
|
||||||
|
ln = ln.view(shape_x[0], -1, shape_x[2], shape_x[3])
|
||||||
|
return self._activation(self._spectral_conv(x) + ln)
|
||||||
|
|
||||||
|
|
||||||
|
class FourierBlock3D(nn.Module):
|
||||||
|
"""
|
||||||
|
Fourier block implementation for three dimensional
|
||||||
|
input tensor. The combination of Fourier blocks
|
||||||
|
make up the Fourier Neural Operator
|
||||||
|
|
||||||
|
.. seealso::
|
||||||
|
|
||||||
|
**Original reference**: Li, Zongyi, et al.
|
||||||
|
"Fourier neural operator for parametric partial
|
||||||
|
differential equations." arXiv preprint
|
||||||
|
arXiv:2010.08895 (2020)
|
||||||
|
<https://arxiv.org/abs/2010.08895.pdf>`_.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, input_numb_fields, output_numb_fields, n_modes, activation=torch.nn.Tanh):
|
||||||
|
"""
|
||||||
|
PINA implementation of Fourier block three dimensions. The module computes
|
||||||
|
the spectral convolution of the input with a linear kernel in the
|
||||||
|
fourier space, and then it maps the input back to the physical
|
||||||
|
space. The output is then added to a Linear tranformation of the
|
||||||
|
input in the physical space. Finally an activation function is
|
||||||
|
applied to the output.
|
||||||
|
|
||||||
|
The block expects an input of size ``[batch, input_numb_fields, Nx, Ny, Nz]``
|
||||||
|
and returns an output of size ``[batch, output_numb_fields, Nx, Ny, Nz]``.
|
||||||
|
|
||||||
|
:param int input_numb_fields: The number of channels for the input.
|
||||||
|
:param int output_numb_fields: The number of channels for the output.
|
||||||
|
:param list | tuple n_modes: Number of modes to select for each dimension.
|
||||||
|
It must be at most equal to the ``floor(Nx/2)+1``, ``floor(Ny/2)+1``
|
||||||
|
and ``floor(Nz/2)+1``.
|
||||||
|
:param torch.nn.Module activation: The activation function.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# check type consistency
|
||||||
|
check_consistency(activation(), nn.Module)
|
||||||
|
|
||||||
|
# assign variables
|
||||||
|
self._spectral_conv = SpectralConvBlock3D(input_numb_fields=input_numb_fields,
|
||||||
|
output_numb_fields=output_numb_fields,
|
||||||
|
n_modes=n_modes)
|
||||||
|
self._activation = activation()
|
||||||
|
self._linear = nn.Conv1d(input_numb_fields, output_numb_fields, 1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
shape_x = x.shape
|
||||||
|
ln = self._linear(x.view(shape_x[0], shape_x[1], -1))
|
||||||
|
ln = ln.view(shape_x[0], -1, shape_x[2], shape_x[3], shape_x[4])
|
||||||
|
return self._activation(self._spectral_conv(x) + ln)
|
||||||
|
|||||||
@@ -12,14 +12,18 @@ class SpectralConvBlock1D(nn.Module):
|
|||||||
|
|
||||||
def __init__(self, input_numb_fields, output_numb_fields, n_modes):
|
def __init__(self, input_numb_fields, output_numb_fields, n_modes):
|
||||||
"""
|
"""
|
||||||
TODO
|
PINA implementation of spectral convolution. The module computes
|
||||||
|
the spectral convolution of the input with a linear kernel in the
|
||||||
|
fourier space, and then it maps the input back to the physical
|
||||||
|
space.
|
||||||
|
|
||||||
:param input_numb_fields: _description_
|
The block expects an input of size ``[batch, input_numb_fields, N]``
|
||||||
:type input_numb_fields: _type_
|
and returns an output of size ``[batch, output_numb_fields, N]``.
|
||||||
:param output_numb_fields: _description_
|
|
||||||
:type output_numb_fields: _type_
|
:param int input_numb_fields: The number of channels for the input.
|
||||||
:param n_modes: _description_
|
:param int output_numb_fields: The number of channels for the output.
|
||||||
:type n_modes: _type_
|
:param int n_modes: Number of modes to select, it must be at most equal
|
||||||
|
to the ``floor(N/2)+1``.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -69,9 +73,6 @@ class SpectralConvBlock1D(nn.Module):
|
|||||||
"""
|
"""
|
||||||
batch_size = x.shape[0]
|
batch_size = x.shape[0]
|
||||||
|
|
||||||
# if x.shape[-1] // 2 + 1 < self._modes:
|
|
||||||
# raise RuntimeError('Number of modes is too high, decrease number of modes.')
|
|
||||||
|
|
||||||
# Compute Fourier transform of the input
|
# Compute Fourier transform of the input
|
||||||
x_ft = torch.fft.rfft(x)
|
x_ft = torch.fft.rfft(x)
|
||||||
|
|
||||||
@@ -95,6 +96,20 @@ class SpectralConvBlock2D(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, input_numb_fields, output_numb_fields, n_modes):
|
def __init__(self, input_numb_fields, output_numb_fields, n_modes):
|
||||||
|
"""
|
||||||
|
PINA implementation of spectral convolution. The module computes
|
||||||
|
the spectral convolution of the input with a linear kernel in the
|
||||||
|
fourier space, and then it maps the input back to the physical
|
||||||
|
space.
|
||||||
|
|
||||||
|
The block expects an input of size ``[batch, input_numb_fields, Nx, Ny]``
|
||||||
|
and returns an output of size ``[batch, output_numb_fields, Nx, Ny]``.
|
||||||
|
|
||||||
|
:param int input_numb_fields: The number of channels for the input.
|
||||||
|
:param int output_numb_fields: The number of channels for the output.
|
||||||
|
:param list | tuple n_modes: Number of modes to select for each dimension.
|
||||||
|
It must be at most equal to the ``floor(Nx/2)+1`` and ``floor(Ny/2)+1``.
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# check type consistency
|
# check type consistency
|
||||||
@@ -188,16 +203,19 @@ class SpectralConvBlock3D(nn.Module):
|
|||||||
|
|
||||||
def __init__(self, input_numb_fields, output_numb_fields, n_modes):
|
def __init__(self, input_numb_fields, output_numb_fields, n_modes):
|
||||||
"""
|
"""
|
||||||
TODO
|
PINA implementation of spectral convolution. The module computes
|
||||||
|
the spectral convolution of the input with a linear kernel in the
|
||||||
|
fourier space, and then it maps the input back to the physical
|
||||||
|
space.
|
||||||
|
|
||||||
:param input_numb_fields: _description_
|
The block expects an input of size ``[batch, input_numb_fields, Nx, Ny, Nz]``
|
||||||
:type input_numb_fields: _type_
|
and returns an output of size ``[batch, output_numb_fields, Nx, Ny, Nz]``.
|
||||||
:param output_numb_fields: _description_
|
|
||||||
:type output_numb_fields: _type_
|
:param int input_numb_fields: The number of channels for the input.
|
||||||
:param n_modes: _description_
|
:param int output_numb_fields: The number of channels for the output.
|
||||||
:type n_modes: _type_
|
:param list | tuple n_modes: Number of modes to select for each dimension.
|
||||||
:raises ValueError: _description_
|
It must be at most equal to the ``floor(Nx/2)+1``, ``floor(Ny/2)+1``
|
||||||
:raises ValueError: _description_
|
and ``floor(Nz/2)+1``.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
|||||||
43
tests/test_layers/test_fourier.py
Normal file
43
tests/test_layers/test_fourier.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
from pina.model.layers import FourierBlock1D, FourierBlock2D, FourierBlock3D
|
||||||
|
import torch
|
||||||
|
|
||||||
|
input_numb_fields = 3
|
||||||
|
output_numb_fields = 4
|
||||||
|
batch = 5
|
||||||
|
|
||||||
|
def test_constructor_1d():
|
||||||
|
FourierBlock1D(input_numb_fields=input_numb_fields,
|
||||||
|
output_numb_fields=output_numb_fields,
|
||||||
|
n_modes=5)
|
||||||
|
|
||||||
|
def test_forward_1d():
|
||||||
|
sconv = FourierBlock1D(input_numb_fields=input_numb_fields,
|
||||||
|
output_numb_fields=output_numb_fields,
|
||||||
|
n_modes=4)
|
||||||
|
x = torch.rand(batch, input_numb_fields, 10)
|
||||||
|
sconv(x)
|
||||||
|
|
||||||
|
|
||||||
|
def test_constructor_2d():
|
||||||
|
FourierBlock2D(input_numb_fields=input_numb_fields,
|
||||||
|
output_numb_fields=output_numb_fields,
|
||||||
|
n_modes=[5, 4])
|
||||||
|
|
||||||
|
def test_forward_2d():
|
||||||
|
sconv = FourierBlock2D(input_numb_fields=input_numb_fields,
|
||||||
|
output_numb_fields=output_numb_fields,
|
||||||
|
n_modes=[5, 4])
|
||||||
|
x = torch.rand(batch, input_numb_fields, 10, 10)
|
||||||
|
sconv(x)
|
||||||
|
|
||||||
|
def test_constructor_3d():
|
||||||
|
FourierBlock3D(input_numb_fields=input_numb_fields,
|
||||||
|
output_numb_fields=output_numb_fields,
|
||||||
|
n_modes=[5, 4, 4])
|
||||||
|
|
||||||
|
def test_forward_3d():
|
||||||
|
sconv = FourierBlock3D(input_numb_fields=input_numb_fields,
|
||||||
|
output_numb_fields=output_numb_fields,
|
||||||
|
n_modes=[5, 4, 4])
|
||||||
|
x = torch.rand(batch, input_numb_fields, 10, 10, 10)
|
||||||
|
sconv(x)
|
||||||
Reference in New Issue
Block a user