Unifying integral kernel NO architectures (#239)
* Unify integral kernel NO architectures with NeuralKernelOperator * Implement FNO based on NeuralKernelOperator * modify doc for FNO and add for FourierIntegralKernel, NeuralKernelOperator * adding tests --------- Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local> Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
This commit is contained in:
@@ -48,11 +48,13 @@ Models
|
|||||||
:maxdepth: 5
|
:maxdepth: 5
|
||||||
|
|
||||||
Network <models/network.rst>
|
Network <models/network.rst>
|
||||||
|
KernelNeuralOperator <models/base_no.rst>
|
||||||
FeedForward <models/fnn.rst>
|
FeedForward <models/fnn.rst>
|
||||||
MultiFeedForward <models/multifeedforward.rst>
|
MultiFeedForward <models/multifeedforward.rst>
|
||||||
ResidualFeedForward <models/fnn_residual.rst>
|
ResidualFeedForward <models/fnn_residual.rst>
|
||||||
DeepONet <models/deeponet.rst>
|
DeepONet <models/deeponet.rst>
|
||||||
MIONet <models/mionet.rst>
|
MIONet <models/mionet.rst>
|
||||||
|
FourierIntegralKernel <models/fourier_kernel.rst>
|
||||||
FNO <models/fno.rst>
|
FNO <models/fno.rst>
|
||||||
|
|
||||||
Layers
|
Layers
|
||||||
|
|||||||
7
docs/source/_rst/models/base_no.rst
Normal file
7
docs/source/_rst/models/base_no.rst
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
KernelNeuralOperator
|
||||||
|
=======================
|
||||||
|
.. currentmodule:: pina.model.base_no
|
||||||
|
|
||||||
|
.. autoclass:: KernelNeuralOperator
|
||||||
|
:members:
|
||||||
|
:show-inheritance:
|
||||||
7
docs/source/_rst/models/fourier_kernel.rst
Normal file
7
docs/source/_rst/models/fourier_kernel.rst
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
FourierIntegralKernel
|
||||||
|
=========================
|
||||||
|
.. currentmodule:: pina.model.fno
|
||||||
|
|
||||||
|
.. autoclass:: FourierIntegralKernel
|
||||||
|
:members:
|
||||||
|
:show-inheritance:
|
||||||
@@ -1,13 +1,16 @@
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"FeedForward",
|
'FeedForward',
|
||||||
"ResidualFeedForward",
|
'ResidualFeedForward',
|
||||||
"MultiFeedForward",
|
'MultiFeedForward',
|
||||||
"DeepONet",
|
'DeepONet',
|
||||||
"MIONet",
|
'MIONet',
|
||||||
"FNO",
|
'FNO',
|
||||||
|
'FourierIntegralKernel',
|
||||||
|
'KernelNeuralOperator'
|
||||||
]
|
]
|
||||||
|
|
||||||
from .feed_forward import FeedForward, ResidualFeedForward
|
from .feed_forward import FeedForward, ResidualFeedForward
|
||||||
from .multi_feed_forward import MultiFeedForward
|
from .multi_feed_forward import MultiFeedForward
|
||||||
from .deeponet import DeepONet, MIONet
|
from .deeponet import DeepONet, MIONet
|
||||||
from .fno import FNO
|
from .fno import FNO, FourierIntegralKernel
|
||||||
|
from .base_no import KernelNeuralOperator
|
||||||
|
|||||||
136
pina/model/base_no.py
Normal file
136
pina/model/base_no.py
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
"""
|
||||||
|
Kernel Neural Operator Module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pina.utils import check_consistency
|
||||||
|
|
||||||
|
|
||||||
|
class KernelNeuralOperator(torch.nn.Module):
|
||||||
|
r"""
|
||||||
|
Base class for composing Neural Operators with integral kernels.
|
||||||
|
|
||||||
|
This is a base class for composing neural operators with multiple
|
||||||
|
integral kernels. All neural operator models defined in PINA inherit
|
||||||
|
from this class. The structure is inspired by the work of Kovachki, N.
|
||||||
|
et al. see Figure 2 of the reference for extra details. The Neural
|
||||||
|
Operators inheriting from this class can be written as:
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
G_\theta := P \circ K_m \circ \cdot \circ K_1 \circ L
|
||||||
|
|
||||||
|
where:
|
||||||
|
|
||||||
|
* :math:`G_\theta: \mathcal{A}\subset \mathbb{R}^{\rm{in}} \rightarrow
|
||||||
|
\mathcal{D}\subset \mathbb{R}^{\rm{out}}` is the neural operator
|
||||||
|
approximation of the unknown real operator :math:`G`, that is
|
||||||
|
:math:`G \approx G_\theta`
|
||||||
|
* :math:`L: \mathcal{A}\subset \mathbb{R}^{\rm{in}} \rightarrow
|
||||||
|
\mathbb{R}^{\rm{emb}}` is a lifting operator mapping the input
|
||||||
|
from its domain :math:`\mathcal{A}\subset \mathbb{R}^{\rm{in}}`
|
||||||
|
to its embedding dimension :math:`\mathbb{R}^{\rm{emb}}`
|
||||||
|
* :math:`\{K_i : \mathbb{R}^{\rm{emb}} \rightarrow
|
||||||
|
\mathbb{R}^{\rm{emb}} \}_{i=1}^m` are :math:`m` integral kernels
|
||||||
|
mapping each hidden representation to the next one.
|
||||||
|
* :math:`P : \mathbb{R}^{\rm{emb}} \rightarrow \mathcal{D}\subset
|
||||||
|
\mathbb{R}^{\rm{out}}` is a projection operator mapping the hidden
|
||||||
|
representation to the output function.
|
||||||
|
|
||||||
|
.. seealso::
|
||||||
|
|
||||||
|
**Original reference**: Kovachki, N., Li, Z., Liu, B.,
|
||||||
|
Azizzadenesheli, K., Bhattacharya, K., Stuart, A., & Anandkumar, A.
|
||||||
|
(2023). *Neural operator: Learning maps between function
|
||||||
|
spaces with applications to PDEs*. Journal of Machine Learning
|
||||||
|
Research, 24(89), 1-97.
|
||||||
|
"""
|
||||||
|
def __init__(self, lifting_operator, integral_kernels, projection_operator):
|
||||||
|
"""
|
||||||
|
:param torch.nn.Module lifting_operator: The lifting operator
|
||||||
|
mapping the input to its hidden dimension.
|
||||||
|
:param torch.nn.Module integral_kernels: List of integral kernels
|
||||||
|
mapping each hidden representation to the next one.
|
||||||
|
:param torch.nn.Module projection_operator: The projection operator
|
||||||
|
mapping the hidden representation to the output function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self._lifting_operator = lifting_operator
|
||||||
|
self._integral_kernels = integral_kernels
|
||||||
|
self._projection_operator = projection_operator
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lifting_operator(self):
|
||||||
|
"""
|
||||||
|
The lifting operator property.
|
||||||
|
"""
|
||||||
|
return self._lifting_operator
|
||||||
|
|
||||||
|
@lifting_operator.setter
|
||||||
|
def lifting_operator(self, value):
|
||||||
|
"""
|
||||||
|
The lifting operator setter
|
||||||
|
|
||||||
|
:param torch.nn.Module value: The lifting operator torch module.
|
||||||
|
"""
|
||||||
|
check_consistency(value, torch.nn.Module)
|
||||||
|
self._lifting_operator = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def projection_operator(self):
|
||||||
|
"""
|
||||||
|
The projection operator property.
|
||||||
|
"""
|
||||||
|
return self._projection_operator
|
||||||
|
|
||||||
|
@projection_operator.setter
|
||||||
|
def projection_operator(self, value):
|
||||||
|
"""
|
||||||
|
The projection operator setter
|
||||||
|
|
||||||
|
:param torch.nn.Module value: The projection operator torch module.
|
||||||
|
"""
|
||||||
|
check_consistency(value, torch.nn.Module)
|
||||||
|
self._projection_operator = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def integral_kernels(self):
|
||||||
|
"""
|
||||||
|
The integral kernels operator property.
|
||||||
|
"""
|
||||||
|
return self._integral_kernels
|
||||||
|
|
||||||
|
@integral_kernels.setter
|
||||||
|
def integral_kernels(self, value):
|
||||||
|
"""
|
||||||
|
The integral kernels operator setter
|
||||||
|
|
||||||
|
:param torch.nn.Module value: The integral kernels operator torch
|
||||||
|
module.
|
||||||
|
"""
|
||||||
|
check_consistency(value, torch.nn.Module)
|
||||||
|
self._integral_kernels = value
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
r"""
|
||||||
|
Forward computation for Base Neural Operator. It performs a
|
||||||
|
lifting of the input by the ``lifting_operator``.
|
||||||
|
Then different layers integral kernels are applied using
|
||||||
|
``integral_kernels``. Finally the output is projected
|
||||||
|
to the final dimensionality by the ``projection_operator``.
|
||||||
|
|
||||||
|
:param torch.Tensor x: The input tensor for performing the
|
||||||
|
computation. It expects a tensor :math:`B \times N \times D`,
|
||||||
|
where :math:`B` is the batch_size, :math:`N` the number of points
|
||||||
|
in the mesh, :math:`D` the dimension of the problem. In particular
|
||||||
|
:math:`D` is the number of spatial/paramtric/temporal variables
|
||||||
|
plus the field variables. For example for 2D problems with 2
|
||||||
|
output\ variables :math:`D=4`.
|
||||||
|
:return: The output tensor obtained from the NO.
|
||||||
|
:rtype: torch.Tensor
|
||||||
|
"""
|
||||||
|
x = self.lifting_operator(x)
|
||||||
|
x = self.integral_kernels(x)
|
||||||
|
x = self.projection_operator(x)
|
||||||
|
return x
|
||||||
@@ -1,34 +1,36 @@
|
|||||||
|
"""
|
||||||
|
Fourier Neural Operator Module.
|
||||||
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from ..utils import check_consistency
|
|
||||||
from .layers.fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D
|
|
||||||
from pina import LabelTensor
|
from pina import LabelTensor
|
||||||
import warnings
|
import warnings
|
||||||
|
from ..utils import check_consistency
|
||||||
|
from .layers.fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D
|
||||||
|
from .base_no import KernelNeuralOperator
|
||||||
|
|
||||||
|
|
||||||
class FNO(torch.nn.Module):
|
class FourierIntegralKernel(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
The PINA implementation of Fourier Neural Operator network.
|
Implementation of Fourier Integral Kernel network.
|
||||||
|
|
||||||
Fourier Neural Operator (FNO) is a general architecture for learning Operators.
|
This class implements the Fourier Integral Kernel network, which is a
|
||||||
Unlike traditional machine learning methods FNO is designed to map
|
PINA implementation of Fourier Neural Operator kernel network.
|
||||||
entire functions to other functions. It can be trained both with
|
It performs global convolution by operating in the Fourier space.
|
||||||
Supervised learning strategies. FNO does global convolution by performing the
|
|
||||||
operation on the Fourier space.
|
|
||||||
|
|
||||||
.. seealso::
|
.. seealso::
|
||||||
|
|
||||||
**Original reference**: Li, Z., Kovachki, N., Azizzadenesheli, K., Liu, B.,
|
**Original reference**: Li, Z., Kovachki, N., Azizzadenesheli,
|
||||||
Bhattacharya, K., Stuart, A., & Anandkumar, A. (2020). *Fourier neural operator for
|
K., Liu, B., Bhattacharya, K., Stuart, A., & Anandkumar, A.
|
||||||
parametric partial differential equations*.
|
(2020). *Fourier neural operator for parametric partial
|
||||||
|
differential equations*.
|
||||||
DOI: `arXiv preprint arXiv:2010.08895.
|
DOI: `arXiv preprint arXiv:2010.08895.
|
||||||
<https://arxiv.org/abs/2010.08895>`_
|
<https://arxiv.org/abs/2010.08895>`_
|
||||||
"""
|
"""
|
||||||
|
def __init__(self,
|
||||||
def __init__(
|
input_numb_fields,
|
||||||
self,
|
output_numb_fields,
|
||||||
lifting_net,
|
|
||||||
projecting_net,
|
|
||||||
n_modes,
|
n_modes,
|
||||||
dimensions=3,
|
dimensions=3,
|
||||||
padding=8,
|
padding=8,
|
||||||
@@ -36,19 +38,29 @@ class FNO(torch.nn.Module):
|
|||||||
inner_size=20,
|
inner_size=20,
|
||||||
n_layers=2,
|
n_layers=2,
|
||||||
func=nn.Tanh,
|
func=nn.Tanh,
|
||||||
layers=None,
|
layers=None):
|
||||||
):
|
"""
|
||||||
|
:param int input_numb_fields: Number of input fields.
|
||||||
|
:param int output_numb_fields: Number of output fields.
|
||||||
|
:param int | list[int] n_modes: Number of modes.
|
||||||
|
:param int dimensions: Number of dimensions (1, 2, or 3).
|
||||||
|
:param int padding: Padding size, defaults to 8.
|
||||||
|
:param str padding_type: Type of padding, defaults to "constant".
|
||||||
|
:param int inner_size: Inner size, defaults to 20.
|
||||||
|
:param int n_layers: Number of layers, defaults to 2.
|
||||||
|
:param torch.nn.Module func: Activation function, defaults to nn.Tanh.
|
||||||
|
:param list[int] layers: List of layer sizes, defaults to None.
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# check type consistency
|
# check type consistency
|
||||||
check_consistency(lifting_net, nn.Module)
|
|
||||||
check_consistency(projecting_net, nn.Module)
|
|
||||||
check_consistency(dimensions, int)
|
check_consistency(dimensions, int)
|
||||||
check_consistency(padding, int)
|
check_consistency(padding, int)
|
||||||
check_consistency(padding_type, str)
|
check_consistency(padding_type, str)
|
||||||
check_consistency(inner_size, int)
|
check_consistency(inner_size, int)
|
||||||
check_consistency(n_layers, int)
|
check_consistency(n_layers, int)
|
||||||
check_consistency(func, nn.Module, subclass=True)
|
check_consistency(func, nn.Module, subclass=True)
|
||||||
|
|
||||||
if layers is not None:
|
if layers is not None:
|
||||||
if isinstance(layers, (tuple, list)):
|
if isinstance(layers, (tuple, list)):
|
||||||
check_consistency(layers, int)
|
check_consistency(layers, int)
|
||||||
@@ -57,13 +69,9 @@ class FNO(torch.nn.Module):
|
|||||||
if not isinstance(n_modes, (list, tuple, int)):
|
if not isinstance(n_modes, (list, tuple, int)):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"n_modes must be a int or list or tuple of valid modes."
|
"n_modes must be a int or list or tuple of valid modes."
|
||||||
" More information on the official documentation."
|
" More information on the official documentation.")
|
||||||
)
|
|
||||||
|
|
||||||
# assign variables
|
# assign padding
|
||||||
# TODO check input lifting net and input projecting net
|
|
||||||
self._lifting_net = lifting_net
|
|
||||||
self._projecting_net = projecting_net
|
|
||||||
self._padding = padding
|
self._padding = padding
|
||||||
|
|
||||||
# initialize fourier layer for each dimension
|
# initialize fourier layer for each dimension
|
||||||
@@ -74,9 +82,11 @@ class FNO(torch.nn.Module):
|
|||||||
elif dimensions == 3:
|
elif dimensions == 3:
|
||||||
fourier_layer = FourierBlock3D
|
fourier_layer = FourierBlock3D
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("FNO implemented only for 1D/2D/3D data.")
|
raise NotImplementedError(
|
||||||
|
"FNO implemented only for 1D/2D/3D data."
|
||||||
|
)
|
||||||
|
|
||||||
# Here we build the FNO by stacking Fourier Blocks
|
# Here we build the FNO kernels by stacking Fourier Blocks
|
||||||
|
|
||||||
# 1. Assign output dimensions for each FNO layer
|
# 1. Assign output dimensions for each FNO layer
|
||||||
if layers is None:
|
if layers is None:
|
||||||
@@ -86,43 +96,33 @@ class FNO(torch.nn.Module):
|
|||||||
if isinstance(func, list):
|
if isinstance(func, list):
|
||||||
if len(layers) != len(func):
|
if len(layers) != len(func):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Uncosistent number of layers and functions."
|
'Uncosistent number of layers and functions.')
|
||||||
)
|
_functions = func
|
||||||
self._functions = func
|
|
||||||
else:
|
else:
|
||||||
self._functions = [func for _ in range(len(layers))]
|
_functions = [func for _ in range(len(layers) - 1)]
|
||||||
|
_functions.append(torch.nn.Identity)
|
||||||
|
|
||||||
# 3. Assign modes functions for each FNO layer
|
# 3. Assign modes functions for each FNO layer
|
||||||
if isinstance(n_modes, list):
|
if isinstance(n_modes, list):
|
||||||
if all(isinstance(i, list) for i in n_modes) and len(layers) != len(
|
if all(isinstance(i, list)
|
||||||
n_modes
|
for i in n_modes) and len(layers) != len(n_modes):
|
||||||
):
|
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Uncosistent number of layers and functions."
|
"Uncosistent number of layers and functions.")
|
||||||
)
|
|
||||||
elif all(isinstance(i, int) for i in n_modes):
|
elif all(isinstance(i, int) for i in n_modes):
|
||||||
n_modes = [n_modes] * len(layers)
|
n_modes = [n_modes] * len(layers)
|
||||||
else:
|
else:
|
||||||
n_modes = [n_modes] * len(layers)
|
n_modes = [n_modes] * len(layers)
|
||||||
|
|
||||||
# 4. Build the FNO network
|
# 4. Build the FNO network
|
||||||
tmp_layers = layers.copy()
|
_layers = []
|
||||||
first_parameter = next(lifting_net.parameters())
|
tmp_layers = [input_numb_fields] + layers + [output_numb_fields]
|
||||||
input_shape = first_parameter.size()
|
for i in range(len(layers)):
|
||||||
out_feats = lifting_net(torch.rand(size=input_shape)).shape[-1]
|
_layers.append(
|
||||||
tmp_layers.insert(0, out_feats)
|
fourier_layer(input_numb_fields=tmp_layers[i],
|
||||||
|
|
||||||
self._layers = []
|
|
||||||
for i in range(len(tmp_layers) - 1):
|
|
||||||
self._layers.append(
|
|
||||||
fourier_layer(
|
|
||||||
input_numb_fields=tmp_layers[i],
|
|
||||||
output_numb_fields=tmp_layers[i + 1],
|
output_numb_fields=tmp_layers[i + 1],
|
||||||
n_modes=n_modes[i],
|
n_modes=n_modes[i],
|
||||||
activation=self._functions[i],
|
activation=_functions[i]))
|
||||||
)
|
self._layers = nn.Sequential(*_layers)
|
||||||
)
|
|
||||||
self._layers = nn.Sequential(*self._layers)
|
|
||||||
|
|
||||||
# 5. Padding values for spectral conv
|
# 5. Padding values for spectral conv
|
||||||
if isinstance(padding, int):
|
if isinstance(padding, int):
|
||||||
@@ -140,23 +140,22 @@ class FNO(torch.nn.Module):
|
|||||||
of Fourier Blocks are applied. Finally the output is projected
|
of Fourier Blocks are applied. Finally the output is projected
|
||||||
to the final dimensionality by the ``projecting_net``.
|
to the final dimensionality by the ``projecting_net``.
|
||||||
|
|
||||||
:param torch.Tensor x: The input tensor for fourier block, depending on
|
:param torch.Tensor x: The input tensor for fourier block,
|
||||||
``dimension`` in the initialization. In particular it is expected
|
depending on ``dimension`` in the initialization.
|
||||||
|
In particular it is expected:
|
||||||
|
|
||||||
* 1D tensors: ``[batch, X, channels]``
|
* 1D tensors: ``[batch, X, channels]``
|
||||||
* 2D tensors: ``[batch, X, Y, channels]``
|
* 2D tensors: ``[batch, X, Y, channels]``
|
||||||
* 3D tensors: ``[batch, X, Y, Z, channels]``
|
* 3D tensors: ``[batch, X, Y, Z, channels]``
|
||||||
:return: The output tensor obtained from the FNO.
|
:return: The output tensor obtained from the kernels convolution.
|
||||||
:rtype: torch.Tensor
|
:rtype: torch.Tensor
|
||||||
"""
|
"""
|
||||||
if isinstance(x, LabelTensor): #TODO remove when Network is fixed
|
if isinstance(x, LabelTensor): #TODO remove when Network is fixed
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"LabelTensor passed as input is not allowed, casting LabelTensor to Torch.Tensor"
|
'LabelTensor passed as input is not allowed,'
|
||||||
|
' casting LabelTensor to Torch.Tensor'
|
||||||
)
|
)
|
||||||
x = x.as_subclass(torch.Tensor)
|
x = x.as_subclass(torch.Tensor)
|
||||||
|
|
||||||
# lifting the input in higher dimensional space
|
|
||||||
x = self._lifting_net(x)
|
|
||||||
|
|
||||||
# permuting the input [batch, channels, x, y, ...]
|
# permuting the input [batch, channels, x, y, ...]
|
||||||
permutation_idx = [0, x.ndim - 1, *[i for i in range(1, x.ndim - 1)]]
|
permutation_idx = [0, x.ndim - 1, *[i for i in range(1, x.ndim - 1)]]
|
||||||
x = x.permute(permutation_idx)
|
x = x.permute(permutation_idx)
|
||||||
@@ -175,5 +174,85 @@ class FNO(torch.nn.Module):
|
|||||||
permutation_idx = [0, *[i for i in range(2, x.ndim)], 1]
|
permutation_idx = [0, *[i for i in range(2, x.ndim)], 1]
|
||||||
x = x.permute(permutation_idx)
|
x = x.permute(permutation_idx)
|
||||||
|
|
||||||
# apply projecting operator and return
|
return x
|
||||||
return self._projecting_net(x)
|
|
||||||
|
|
||||||
|
class FNO(KernelNeuralOperator):
|
||||||
|
"""
|
||||||
|
The PINA implementation of Fourier Neural Operator network.
|
||||||
|
|
||||||
|
Fourier Neural Operator (FNO) is a general architecture for
|
||||||
|
learning Operators. Unlike traditional machine learning methods
|
||||||
|
FNO is designed to map entire functions to other functions. It
|
||||||
|
can be trained with Supervised learning strategies. FNO does global
|
||||||
|
convolution by performing the operation on the Fourier space.
|
||||||
|
|
||||||
|
.. seealso::
|
||||||
|
|
||||||
|
**Original reference**: Li, Z., Kovachki, N., Azizzadenesheli,
|
||||||
|
K., Liu, B., Bhattacharya, K., Stuart, A., & Anandkumar, A.
|
||||||
|
(2020). *Fourier neural operator for parametric partial
|
||||||
|
differential equations*.
|
||||||
|
DOI: `arXiv preprint arXiv:2010.08895.
|
||||||
|
<https://arxiv.org/abs/2010.08895>`_
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
lifting_net,
|
||||||
|
projecting_net,
|
||||||
|
n_modes,
|
||||||
|
dimensions=3,
|
||||||
|
padding=8,
|
||||||
|
padding_type="constant",
|
||||||
|
inner_size=20,
|
||||||
|
n_layers=2,
|
||||||
|
func=nn.Tanh,
|
||||||
|
layers=None):
|
||||||
|
"""
|
||||||
|
:param torch.nn.Module lifting_net: The neural network for lifting
|
||||||
|
the input.
|
||||||
|
:param torch.nn.Module projecting_net: The neural network for
|
||||||
|
projecting the output.
|
||||||
|
:param int | list[int] n_modes: Number of modes.
|
||||||
|
:param int dimensions: Number of dimensions (1, 2, or 3).
|
||||||
|
:param int padding: Padding size, defaults to 8.
|
||||||
|
:param str padding_type: Type of padding, defaults to `constant`.
|
||||||
|
:param int inner_size: Inner size, defaults to 20.
|
||||||
|
:param int n_layers: Number of layers, defaults to 2.
|
||||||
|
:param torch.nn.Module func: Activation function, defaults to nn.Tanh.
|
||||||
|
:param list[int] layers: List of layer sizes, defaults to None.
|
||||||
|
"""
|
||||||
|
lifting_operator_out = lifting_net(
|
||||||
|
torch.rand(size=next(lifting_net.parameters()).size())).shape[-1]
|
||||||
|
super().__init__(lifting_operator=lifting_net,
|
||||||
|
projection_operator=projecting_net,
|
||||||
|
integral_kernels=FourierIntegralKernel(
|
||||||
|
input_numb_fields=lifting_operator_out,
|
||||||
|
output_numb_fields=next(
|
||||||
|
projecting_net.parameters()).size(),
|
||||||
|
n_modes=n_modes,
|
||||||
|
dimensions=dimensions,
|
||||||
|
padding=padding,
|
||||||
|
padding_type=padding_type,
|
||||||
|
inner_size=inner_size,
|
||||||
|
n_layers=n_layers,
|
||||||
|
func=func,
|
||||||
|
layers=layers))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Forward computation for Fourier Neural Operator. It performs a
|
||||||
|
lifting of the input by the ``lifting_net``. Then different layers
|
||||||
|
of Fourier Blocks are applied. Finally the output is projected
|
||||||
|
to the final dimensionality by the ``projecting_net``.
|
||||||
|
|
||||||
|
:param torch.Tensor x: The input tensor for fourier block,
|
||||||
|
depending on ``dimension`` in the initialization. In
|
||||||
|
particular it is expected:
|
||||||
|
|
||||||
|
* 1D tensors: ``[batch, X, channels]``
|
||||||
|
* 2D tensors: ``[batch, X, Y, channels]``
|
||||||
|
* 3D tensors: ``[batch, X, Y, Z, channels]``
|
||||||
|
:return: The output tensor obtained from FNO.
|
||||||
|
:rtype: torch.Tensor
|
||||||
|
"""
|
||||||
|
return super().forward(x)
|
||||||
|
|||||||
40
tests/test_model/test_base_no.py
Normal file
40
tests/test_model/test_base_no.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
import torch
|
||||||
|
from pina.model import KernelNeuralOperator, FeedForward
|
||||||
|
|
||||||
|
input_dim = 2
|
||||||
|
output_dim = 4
|
||||||
|
embedding_dim = 24
|
||||||
|
batch_size = 10
|
||||||
|
numb = 256
|
||||||
|
data = torch.rand(size=(batch_size, numb, input_dim), requires_grad=True)
|
||||||
|
output_shape = torch.Size([batch_size, numb, output_dim])
|
||||||
|
|
||||||
|
|
||||||
|
lifting_operator = FeedForward(input_dimensions=input_dim, output_dimensions=embedding_dim)
|
||||||
|
projection_operator = FeedForward(input_dimensions=embedding_dim, output_dimensions=output_dim)
|
||||||
|
integral_kernels = torch.nn.Sequential(FeedForward(input_dimensions=embedding_dim,
|
||||||
|
output_dimensions=embedding_dim),
|
||||||
|
FeedForward(input_dimensions=embedding_dim,
|
||||||
|
output_dimensions=embedding_dim),)
|
||||||
|
|
||||||
|
def test_constructor():
|
||||||
|
KernelNeuralOperator(lifting_operator=lifting_operator,
|
||||||
|
integral_kernels=integral_kernels,
|
||||||
|
projection_operator=projection_operator)
|
||||||
|
|
||||||
|
def test_forward():
|
||||||
|
operator = KernelNeuralOperator(lifting_operator=lifting_operator,
|
||||||
|
integral_kernels=integral_kernels,
|
||||||
|
projection_operator=projection_operator)
|
||||||
|
out = operator(data)
|
||||||
|
assert out.shape == output_shape
|
||||||
|
|
||||||
|
def test_backward():
|
||||||
|
operator = KernelNeuralOperator(lifting_operator=lifting_operator,
|
||||||
|
integral_kernels=integral_kernels,
|
||||||
|
projection_operator=projection_operator)
|
||||||
|
out = operator(data)
|
||||||
|
loss = torch.nn.functional.mse_loss(out, torch.zeros_like(out))
|
||||||
|
loss.backward()
|
||||||
|
grad = data.grad
|
||||||
|
assert grad.shape == data.shape
|
||||||
Reference in New Issue
Block a user