fix doc model part 1

This commit is contained in:
giovanni
2025-03-14 12:24:27 +01:00
committed by Nicola Demo
parent def8f5a1d3
commit 8dc682c849
10 changed files with 676 additions and 433 deletions

View File

@@ -1,5 +1,5 @@
"""
Module containing the neural network models.
Module for the Neural model classes.
"""
__all__ = [

View File

@@ -1,4 +1,4 @@
"""Module Averaging Neural Operator."""
"""Module for the Averaging Neural Operator model class."""
import torch
from torch import nn
@@ -9,19 +9,17 @@ from ..utils import check_consistency
class AveragingNeuralOperator(KernelNeuralOperator):
"""
Implementation of Averaging Neural Operator.
Averaging Neural Operator model class.
Averaging Neural Operator is a general architecture for
learning Operators. Unlike traditional machine learning methods
AveragingNeuralOperator is designed to map entire functions
to other functions. It can be trained with Supervised learning strategies.
AveragingNeuralOperator does convolution by performing a field average.
The Averaging Neural Operator is a general architecture for learning
operators, which map functions to functions. It can be trained both with
Supervised and Physics-Informed learning strategies. The Averaging Neural
Operator performs convolution by means of a field average.
.. seealso::
**Original reference**: Lanthaler S. Li, Z., Kovachki,
Stuart, A. (2020). *The Nonlocal Neural Operator:
Universal Approximation*.
**Original reference**: Lanthaler, S., Li, Z., Stuart, A. (2023).
*The Nonlocal Neural Operator: Universal Approximation*.
DOI: `arXiv preprint arXiv:2304.13221.
<https://arxiv.org/abs/2304.13221>`_
"""
@@ -36,21 +34,26 @@ class AveragingNeuralOperator(KernelNeuralOperator):
func=nn.GELU,
):
"""
:param torch.nn.Module lifting_net: The neural network for lifting
the input. It must take as input the input field and the coordinates
at which the input field is avaluated. The output of the lifting
net is chosen as embedding dimension of the problem
:param torch.nn.Module projecting_net: The neural network for
projecting the output. It must take as input the embedding dimension
(output of the ``lifting_net``) plus the dimension
of the coordinates.
:param list[str] field_indices: the label of the fields
in the input tensor.
:param list[str] coordinates_indices: the label of the
coordinates in the input tensor.
:param int n_layers: number of hidden layers. Default is 4.
:param torch.nn.Module func: the activation function to use,
default to torch.nn.GELU.
Initialization of the :class:`AveragingNeuralOperator` class.
:param torch.nn.Module lifting_net: The lifting neural network mapping
the input to its hidden dimension. It must take as input the input
field and the coordinates at which the input field is evaluated.
:param torch.nn.Module projecting_net: The projection neural network
mapping the hidden representation to the output function. It must
take as input the embedding dimension plus the dimension of the
coordinates.
:param list[str] field_indices: The labels of the fields in the input
tensor.
:param list[str] coordinates_indices: The labels of the coordinates in
the input tensor.
:param int n_layers: The number of hidden layers. Default is ``4``.
:param torch.nn.Module func: The activation function to use.
Default is :class:`torch.nn.GELU`.
:raises ValueError: If the input dimension does not match with the
labels of the fields and coordinates.
:raises ValueError: If the input dimension of the projecting network
does not match with the hidden dimension of the lifting network.
"""
# check consistency
@@ -93,19 +96,20 @@ class AveragingNeuralOperator(KernelNeuralOperator):
def forward(self, x):
r"""
Forward computation for Averaging Neural Operator. It performs a
lifting of the input by the ``lifting_net``. Then different layers
of Averaging Neural Operator Blocks are applied.
Finally the output is projected to the final dimensionality
by the ``projecting_net``.
Forward pass for the :class:`AveragingNeuralOperator` model.
:param torch.Tensor x: The input tensor for fourier block,
depending on ``dimension`` in the initialization. It expects
a tensor :math:`B \times N \times D`,
where :math:`B` is the batch_size, :math:`N` the number of points
in the mesh, :math:`D` the dimension of the problem, i.e. the sum
of ``len(coordinates_indices)+len(field_indices)``.
:return: The output tensor obtained from Average Neural Operator.
The ``lifting_net`` maps the input to the hidden dimension.
Then, several layers of
:class:`~pina.model.block.average_neural_operator_block.AVNOBlock` are
applied. Finally, the ``projection_net`` maps the hidden representation
to the output function.
:param LabelTensor x: The input tensor for performing the computation.
It expects a tensor :math:`B \times N \times D`, where :math:`B` is
the batch_size, :math:`N` the number of points in the mesh,
:math:`D` the dimension of the problem, i.e. the sum
of ``len(coordinates_indices)`` and ``len(field_indices)``.
:return: The output tensor.
:rtype: torch.Tensor
"""
points_tmp = x.extract(self.coordinates_indices)
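For reference, a minimal usage sketch consistent with the documented signature; the dimensions, labels, and the two ``FeedForward`` nets below are illustrative assumptions, not part of this commit:

import torch
from pina import LabelTensor
from pina.model import FeedForward, AveragingNeuralOperator

# Assumed setup: two coordinates ("x", "y") and one field ("u"), so the
# input has D = 3 components; the lifting net fixes a hidden size of 24.
lifting = FeedForward(input_dimensions=3, output_dimensions=24)
# The projection net takes hidden size plus coordinate dimension (24 + 2).
projection = FeedForward(input_dimensions=26, output_dimensions=1)
model = AveragingNeuralOperator(
    lifting_net=lifting,
    projecting_net=projection,
    field_indices=["u"],
    coordinates_indices=["x", "y"],
)
# Input: a B x N x D LabelTensor whose labels match the indices above.
data = LabelTensor(torch.rand(8, 100, 3), labels=["x", "y", "u"])
out = model(data)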

View File

@@ -1,4 +1,4 @@
"""Module for DeepONet model"""
"""Module for the DeepONet and MIONet model classes"""
from functools import partial
import torch
@@ -8,22 +8,18 @@ from ..utils import check_consistency, is_function
class MIONet(torch.nn.Module):
"""
The PINA implementation of MIONet network.
MIONet model class.
MIONet is a general architecture for learning Operators defined
on the tensor product of Banach spaces. Unlike traditional machine
learning methods MIONet is designed to map entire functions to other
functions. It can be trained both with Physics Informed or Supervised
learning strategies.
The MIONet is a general architecture for learning operators, which map
functions to functions. It can be trained with both Supervised and
Physics-Informed learning strategies.
.. seealso::
**Original reference**: Jin, Pengzhan, Shuai Meng, and Lu Lu.
**Original reference**: Jin, P., Meng, S., and Lu, L. (2022).
*MIONet: Learning multiple-input operators via tensor product.*
SIAM Journal on Scientific Computing 44.6 (2022): A3490-A351
DOI: `10.1137/22M1477751
<https://doi.org/10.1137/22M1477751>`_
DOI: `10.1137/22M1477751 <https://doi.org/10.1137/22M1477751>`_
"""
def __init__(
@@ -35,42 +31,44 @@ class MIONet(torch.nn.Module):
translation=True,
):
"""
:param dict networks: The neural networks to use as
models. The ``dict`` takes as key a neural network, and
as value the list of indeces to extract from the input variable
in the forward pass of the neural network. If a list of ``int``
is passed, the corresponding columns of the inner most entries are
extracted.
If a list of ``str`` is passed the variables of the corresponding
:py:obj:`pina.label_tensor.LabelTensor`are extracted. The
``torch.nn.Module`` model has to take as input a
:py:obj:`pina.label_tensor.LabelTensor` or :class:`torch.Tensor`.
Default implementation consist of different branch nets and one
Initialization of the :class:`MIONet` class.
:param dict networks: The neural networks to use as models. The ``dict``
takes as key a neural network, and as value the list of indices to
extract from the input variable in the forward pass of the neural
network. If a ``list[int]`` is passed, the corresponding columns of
the innermost entries are extracted. If a ``list[str]`` is passed,
the variables of the corresponding
:class:`~pina.label_tensor.LabelTensor` are extracted.
Each :class:`torch.nn.Module` model has to take as input either a
:class:`~pina.label_tensor.LabelTensor` or a :class:`torch.Tensor`.
Default implementation consists of several branch nets and one
trunk net.
:param str or Callable aggregator: Aggregator to be used to aggregate
partial results from the modules in `nets`. Partial results are
aggregated component-wise. Available aggregators include
sum: ``+``, product: ``*``, mean: ``mean``, min: ``min``, max:
``max``.
:param str or Callable reduction: Reduction to be used to reduce
the aggregated result of the modules in `nets` to the desired output
dimension. Available reductions include
sum: ``+``, product: ``*``, mean: ``mean``, min: ``min``, max:
``max``.
:param bool or Callable scale: Scaling the final output before returning
the forward pass, default ``True``.
:param bool or Callable translation: Translating the final output before
returning the forward pass, default ``True``.
:param aggregator: The aggregator to be used to aggregate component-wise
partial results from the modules in ``networks``. Available
aggregators include: sum: ``+``, product: ``*``, mean: ``mean``,
min: ``min``, max: ``max``. Default is ``*``.
:type aggregator: str or Callable
:param reduction: The reduction to be used to reduce the aggregated
result of the modules in ``networks`` to the desired output
dimension. Available reductions include: sum: ``+``, product: ``*``,
mean: ``mean``, min: ``min``, max: ``max``. Default is ``+``.
:type reduction: str or Callable
:param bool scale: If ``True``, the final output is scaled before being
returned in the forward pass. Default is ``True``.
:param bool translation: If ``True``, the final output is translated
before being returned in the forward pass. Default is ``True``.
:raises ValueError: If the passed networks do not have the same output
dimension.
.. warning::
In the forward pass we do not check if the input is instance of
:py:obj:`pina.label_tensor.LabelTensor` or :class:`torch.Tensor`.
A general rule is that for a :py:obj:`pina.label_tensor.LabelTensor`
input both list of integers and list of strings can be passed for
``input_indeces_branch_net``and ``input_indeces_trunk_net``.
Differently, for a :class:`torch.Tensor` only a list of integers
can be passed for ``input_indeces_branch_net``and
``input_indeces_trunk_net``.
No checks are performed in the forward pass to verify if the input
is instance of either :class:`~pina.label_tensor.LabelTensor` or
:class:`torch.Tensor`. In general, in case of a
:class:`~pina.label_tensor.LabelTensor`, either a ``list[int]`` or a
``list[str]`` can be passed as ``networks`` dict values.
Differently, in case of a :class:`torch.Tensor`, only a
``list[int]`` can be passed as ``networks`` dict values.
:Example:
>>> branch_net1 = FeedForward(input_dimensions=1,
@@ -162,6 +160,10 @@ class MIONet(torch.nn.Module):
"""
Return a dictionary of functions that can be used as aggregators or
reductions.
:param dict kwargs: Additional keyword arguments passed to the
aggregation and reduction functions, e.g. ``dim``.
:return: A dictionary of functions.
:rtype: dict
"""
return {
"+": partial(torch.sum, **kwargs),
@@ -172,6 +174,13 @@ class MIONet(torch.nn.Module):
}
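To illustrate the table above, a standalone sketch of how a symbolic aggregator resolves to a ``functools.partial`` of a torch reduction (the shapes are illustrative):

import torch
from functools import partial

# "+" with dim=2 becomes a sum over the stacked partial results.
aggregator = partial(torch.sum, dim=2)
# Two partial results of shape (batch, features), stacked on dim=2.
stacked = torch.stack([torch.rand(4, 10), torch.rand(4, 10)], dim=2)
print(aggregator(stacked).shape)  # torch.Size([4, 10])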
def _init_aggregator(self, aggregator):
"""
Initialize the aggregator.
:param aggregator: The aggregator to be used to aggregate partial results.
:type aggregator: str or Callable
:raises ValueError: If the aggregator is not supported.
"""
aggregator_funcs = self._symbol_functions(dim=2)
if aggregator in aggregator_funcs:
aggregator_func = aggregator_funcs[aggregator]
@@ -184,6 +193,13 @@ class MIONet(torch.nn.Module):
self._aggregator_type = aggregator
def _init_reduction(self, reduction):
"""
Initialize the reduction.
:param reduction: The reduction to be used.
:type reduction: str or Callable
:raises ValueError: If the reduction is not supported.
"""
reduction_funcs = self._symbol_functions(dim=-1)
if reduction in reduction_funcs:
reduction_func = reduction_funcs[reduction]
@@ -196,6 +212,18 @@ class MIONet(torch.nn.Module):
self._reduction_type = reduction
def _get_vars(self, x, indeces):
"""
Extract the variables from the input tensor.
:param x: The input tensor.
:type x: LabelTensor | torch.Tensor
:param indeces: The indices to extract.
:type indeces: list[int] | list[str]
:raises RuntimeError: If failing to extract the variables.
:raises RuntimeError: If failing to extract the right indices.
:return: The extracted variables.
:rtype: LabelTensor | torch.Tensor
"""
if isinstance(indeces[0], str):
try:
return x.extract(indeces)
@@ -216,12 +244,12 @@ class MIONet(torch.nn.Module):
def forward(self, x):
"""
Defines the computation performed at every call.
Forward pass for the :class:`MIONet` model.
:param LabelTensor or torch.Tensor x: The input tensor for the forward
call.
:return: The output computed by the DeepONet model.
:rtype: LabelTensor or torch.Tensor
:param x: The input tensor.
:type x: LabelTensor | torch.Tensor
:return: The output tensor.
:rtype: LabelTensor | torch.Tensor
"""
# forward pass
@@ -248,13 +276,19 @@ class MIONet(torch.nn.Module):
def aggregator(self):
"""
The aggregator function.
:return: The aggregator function.
:rtype: str or Callable
"""
return self._aggregator
@property
def reduction(self):
"""
The translation factor.
The reduction function.
:return: The reduction function.
:rtype: str or Callable
"""
return self._reduction
@@ -262,13 +296,19 @@ class MIONet(torch.nn.Module):
def scale(self):
"""
The scale factor.
:return: The scale factor.
:rtype: torch.Tensor
"""
return self._scale
@property
def translation(self):
"""
The translation factor for MIONet.
The translation factor.
:return: The translation factor.
:rtype: torch.Tensor
"""
return self._trasl
@@ -276,6 +316,9 @@ class MIONet(torch.nn.Module):
def indeces_variables_extracted(self):
"""
The input indices for each model, in the form of a list.
:return: The indices for each model.
:rtype: list
"""
return self._indeces
@@ -283,24 +326,27 @@ class MIONet(torch.nn.Module):
def model(self):
"""
The models, in the form of a list.
:return: The models.
:rtype: list[torch.nn.Module]
"""
return self._models  # assumed attribute; returning self._indeces here was a copy-paste bug
class DeepONet(MIONet):
"""
The PINA implementation of DeepONet network.
DeepONet model class.
DeepONet is a general architecture for learning Operators. Unlike
traditional machine learning methods DeepONet is designed to map
entire functions to other functions. It can be trained both with
Physics Informed or Supervised learning strategies.
The DeepONet is a general architecture for learning operators, which map
functions to functions. It can be trained with both Supervised and
Physics-Informed learning strategies.
.. seealso::
**Original reference**: Lu, L., Jin, P., Pang, G. et al. *Learning
nonlinear operators via DeepONet based on the universal approximation
theorem of operator*. Nat Mach Intell 3, 218229 (2021).
**Original reference**: Lu, L., Jin, P., Pang, G. et al.
*Learning nonlinear operators via DeepONet based on the universal
approximation theorem of operator*.
Nat Mach Intell 3, 218-229 (2021).
DOI: `10.1038/s42256-021-00302-5
<https://doi.org/10.1038/s42256-021-00302-5>`_
@@ -318,42 +364,44 @@ class DeepONet(MIONet):
translation=True,
):
"""
Initialization of the :class:`DeepONet` class.
:param torch.nn.Module branch_net: The neural network to use as branch
model. It has to take as input a
:py:obj:`pina.label_tensor.LabelTensor` or :class:`torch.Tensor`.
The number of dimensions of the output has to be the same of the
``trunk_net``.
model. It has to take as input either a
:class:`~pina.label_tensor.LabelTensor` or a :class:`torch.Tensor`.
The output dimension has to be the same as that of ``trunk_net``.
:param torch.nn.Module trunk_net: The neural network to use as trunk
model. It has to take as input a
:py:obj:`pina.label_tensor.LabelTensor` or :class:`torch.Tensor`.
The number of dimensions of the output has to be the same of the
``branch_net``.
:param list(int) or list(str) input_indeces_branch_net: List of indeces
to extract from the input variable in the forward pass for the
branch net. If a list of ``int`` is passed, the corresponding
columns of the inner most entries are extracted. If a list of
``str`` is passed the variables of the corresponding
:py:obj:`pina.label_tensor.LabelTensor` are extracted.
:param list(int) or list(str) input_indeces_trunk_net: List of indeces
to extract from the input variable in the forward pass for the
trunk net. If a list of ``int`` is passed, the corresponding columns
of the inner most entries are extracted. If a list of ``str`` is
passed the variables of the corresponding
:py:obj:`pina.label_tensor.LabelTensor` are extracted.
:param str or Callable aggregator: Aggregator to be used to aggregate
partial results from the modules in `nets`. Partial results are
aggregated component-wise. Available aggregators include
sum: ``+``, product: ``*``, mean: ``mean``, min: ``min``,
max: ``max``.
:param str or Callable reduction: Reduction to be used to reduce
the aggregated result of the modules in `nets` to the desired output
dimension. Available reductions include
sum: ``+``, product: ``*``, mean: ``mean``, min: ``min``,
max: ``max``.
:param bool or Callable scale: Scaling the final output before returning
the forward pass, default True.
:param bool or Callable translation: Translating the final output before
returning the forward pass, default True.
model. It has to take as input either a
:class:`~pina.label_tensor.LabelTensor` or a :class:`torch.Tensor`.
The output dimension has to be the same as that of ``branch_net``.
:param input_indeces_branch_net: List of indices to extract from the
input variable of the ``branch_net``.
If a list of ``int`` is passed, the corresponding columns of the
innermost entries are extracted. If a list of ``str`` is passed, the
variables of the corresponding
:class:`~pina.label_tensor.LabelTensor` are extracted.
:type input_indeces_branch_net: list[int] | list[str]
:param input_indeces_trunk_net: List of indices to extract from the
input variable of the ``trunk_net``.
If a list of ``int`` is passed, the corresponding columns of the
innermost entries are extracted. If a list of ``str`` is passed, the
variables of the corresponding
:class:`~pina.label_tensor.LabelTensor` are extracted.
:type input_indeces_trunk_net: list[int] | list[str]
:param aggregator: The aggregator to be used to aggregate component-wise
partial results from the modules in ``networks``. Available
aggregators include: sum: ``+``, product: ``*``, mean: ``mean``,
min: ``min``, max: ``max``. Default is ``*``.
:type aggregator: str or Callable
:param reduction: The reduction to be used to reduce the aggregated
result of the modules in ``networks`` to the desired output
dimension. Available reductions include: sum: ``+``, product: ``*``,
mean: ``mean``, min: ``min``, max: ``max``. Default is ``+``.
:type reduction: str or Callable
:param bool scale: If ``True``, the final output is scaled before being
returned in the forward pass. Default is ``True``.
:param bool translation: If ``True``, the final output is translated
before being returned in the forward pass. Default is ``True``.
.. warning::
In the forward pass we do not check if the input is instance of
@@ -364,6 +412,14 @@ class DeepONet(MIONet):
Differently, for a :class:`torch.Tensor` only a list of integers can
be passed for ``input_indeces_branch_net`` and
``input_indeces_trunk_net``.
.. warning::
No checks are performed in the forward pass to verify if the input
is instance of either :class:`~pina.label_tensor.LabelTensor` or
:class:`torch.Tensor`. In general, in case of a
:class:`~pina.label_tensor.LabelTensor`, either a ``list[int]`` or a
``list[str]`` can be passed as ``input_indeces_branch_net`` and
``input_indeces_trunk_net``. Differently, in case of a
:class:`torch.Tensor`, only a ``list[int]`` can be passed.
:Example:
>>> branch_net = FeedForward(input_dimensions=1,
@@ -411,25 +467,31 @@ class DeepONet(MIONet):
def forward(self, x):
"""
Defines the computation performed at every call.
Forward pass for the :class:`DeepONet` model.
:param LabelTensor or torch.Tensor x: The input tensor for the forward
call.
:return: The output computed by the DeepONet model.
:rtype: LabelTensor or torch.Tensor
:param x: The input tensor.
:type x: LabelTensor | torch.Tensor
:return: The output tensor.
:rtype: LabelTensor | torch.Tensor
"""
return super().forward(x)
@property
def branch_net(self):
"""
The branch net for DeepONet.
The branch net of the DeepONet.
:return: The branch net.
:rtype: torch.nn.Module
"""
return self.models[0]
@property
def trunk_net(self):
"""
The trunk net for DeepONet.
The trunk net of the DeepONet.
:return: The trunk net.
:rtype: torch.nn.Module
"""
return self.models[1]
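For reference, a minimal usage sketch for the :class:`DeepONet` class, assuming plain :class:`torch.Tensor` inputs (hence integer indices); all dimensions below are illustrative:

import torch
from pina.model import DeepONet, FeedForward

# Branch and trunk nets must share the same output dimension (here 10).
branch = FeedForward(input_dimensions=1, output_dimensions=10)
trunk = FeedForward(input_dimensions=2, output_dimensions=10)
model = DeepONet(
    branch_net=branch,
    trunk_net=trunk,
    input_indeces_branch_net=[0],    # column 0 feeds the branch net
    input_indeces_trunk_net=[1, 2],  # columns 1-2 feed the trunk net
)
out = model(torch.rand(16, 3))  # one scalar output per sample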

View File

@@ -1,4 +1,4 @@
"""Module for FeedForward model"""
"""Module for the Feed Forward model class"""
import torch
from torch import nn
@@ -8,28 +8,8 @@ from .block.residual import EnhancedLinear
class FeedForward(torch.nn.Module):
"""
The PINA implementation of feedforward network, also refered as multilayer
perceptron.
:param int input_dimensions: The number of input components of the model.
Expected tensor shape of the form :math:`(*, d)`, where *
means any number of dimensions including none, and :math:`d` the
``input_dimensions``.
:param int output_dimensions: The number of output components of the model.
Expected tensor shape of the form :math:`(*, d)`, where *
means any number of dimensions including none, and :math:`d` the
``output_dimensions``.
:param int inner_size: number of neurons in the hidden layer(s). Default is
20.
:param int n_layers: number of hidden layers. Default is 2.
:param torch.nn.Module func: the activation function to use. If a single
:class:`torch.nn.Module` is passed, this is used as activation function
after any layers, except the last one. If a list of Modules is passed,
they are used as activation functions at any layers, in order.
:param list(int) | tuple(int) layers: a list containing the number of
neurons for any hidden layers. If specified, the parameters ``n_layers``
and ``inner_size`` are not considered.
:param bool bias: If ``True`` the MLP will consider some bias.
Feed Forward neural network model class, also known as Multi-layer
Perceptron.
"""
def __init__(
@@ -42,7 +22,36 @@ class FeedForward(torch.nn.Module):
layers=None,
bias=True,
):
""" """
"""
Initialization of the :class:`FeedForward` class.
:param int input_dimensions: The number of input components.
The expected tensor shape is :math:`(*, d)`, where *
represents any number of preceding dimensions (including none), and
:math:`d` corresponds to ``input_dimensions``.
:param int output_dimensions: The number of output components.
The expected tensor shape is :math:`(*, d)`, where *
represents any number of preceding dimensions (including none), and
:math:`d` corresponds to ``output_dimensions``.
:param int inner_size: The number of neurons for each hidden layer.
Default is ``20``.
:param int n_layers: The number of hidden layers. Default is ``2``.
:param func: The activation function. If a list is passed, it must have
the same length as ``n_layers``. If a single function is passed, it
is used for all layers, except for the last one.
Default is :class:`torch.nn.Tanh`.
:type func: torch.nn.Module | list[torch.nn.Module]
:param list[int] layers: The list of the dimension of inner layers.
If ``None``, ``n_layers`` of dimension ``inner_size`` are used.
Otherwise, it overrides the values passed to ``n_layers`` and
``inner_size``. Default is ``None``.
:param bool bias: If ``True``, bias terms are included in the network
layers. Default is ``True``.
:raises ValueError: If the input dimension is not an integer.
:raises ValueError: If the output dimension is not an integer.
:raises RuntimeError: If the number of layers and functions are
inconsistent.
"""
super().__init__()
if not isinstance(input_dimensions, int):
@@ -71,7 +80,7 @@ class FeedForward(torch.nn.Module):
self.functions = [func for _ in range(len(self.layers) - 1)]
if len(self.layers) != len(self.functions) + 1:
raise RuntimeError("uncosistent number of layers and functions")
raise RuntimeError("Incosistent number of layers and functions")
unique_list = []
for layer, func_ in zip(self.layers[:-1], self.functions):
@@ -84,52 +93,31 @@ class FeedForward(torch.nn.Module):
def forward(self, x):
"""
Defines the computation performed at every call.
Forward pass for the :class:`FeedForward` model.
:param x: The tensor to apply the forward pass.
:type x: torch.Tensor
:return: the output computed by the model.
:rtype: torch.Tensor
:param x: The input tensor.
:type x: torch.Tensor | LabelTensor
:return: The output tensor.
:rtype: torch.Tensor | LabelTensor
"""
return self.model(x)
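A minimal usage sketch showing the two equivalent ways to specify the hidden architecture (dimensions are illustrative):

import torch
from pina.model import FeedForward

# n_layers/inner_size and the explicit layers list are alternatives;
# layers, when given, overrides n_layers and inner_size.
mlp_a = FeedForward(input_dimensions=3, output_dimensions=1,
                    inner_size=20, n_layers=2)
mlp_b = FeedForward(input_dimensions=3, output_dimensions=1,
                    layers=[20, 20])
y = mlp_a(torch.rand(32, 3))  # shape: (32, 1)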
class ResidualFeedForward(torch.nn.Module):
"""
The PINA implementation of feedforward network, also with skipped connection
and transformer network, as presented in **Understanding and mitigating
gradient pathologies in physics-informed neural networks**
Residual Feed Forward neural network model class.
The model is composed of a series of linear layers with a residual
connection between them, as presented in the following:
.. seealso::
**Original reference**: Wang, Sifan, Yujun Teng, and Paris Perdikaris.
**Original reference**: Wang, S., Teng, Y., and Perdikaris, P. (2021).
*Understanding and mitigating gradient flow pathologies in
physics-informed neural networks*. SIAM Journal on Scientific Computing
43.5 (2021): A3055-A3081.
physics-informed neural networks*.
SIAM Journal on Scientific Computing 43.5 (2021): A3055-A3081.
DOI: `10.1137/20M1318043
<https://epubs.siam.org/doi/abs/10.1137/20M1318043>`_
:param int input_dimensions: The number of input components of the model.
Expected tensor shape of the form :math:`(*, d)`, where *
means any number of dimensions including none, and :math:`d` the
``input_dimensions``.
:param int output_dimensions: The number of output components of the model.
Expected tensor shape of the form :math:`(*, d)`, where *
means any number of dimensions including none, and :math:`d` the
``output_dimensions``.
:param int inner_size: number of neurons in the hidden layer(s). Default is
20.
:param int n_layers: number of hidden layers. Default is 2.
:param torch.nn.Module func: the activation function to use. If a single
:class:`torch.nn.Module` is passed, this is used as activation function
after any layers, except the last one. If a list of Modules is passed,
they are used as activation functions at any layers, in order.
:param bool bias: If ``True`` the MLP will consider some bias.
:param list | tuple transformer_nets: a list or tuple containing the two
torch.nn.Module which act as transformer network. The input dimension
of the network must be the same as ``input_dimensions``, and the output
dimension must be the same as ``inner_size``.
"""
def __init__(
@@ -142,7 +130,37 @@ class ResidualFeedForward(torch.nn.Module):
bias=True,
transformer_nets=None,
):
""" """
"""
Initialization of the :class:`ResidualFeedForward` class.
:param int input_dimensions: The number of input components.
The expected tensor shape is :math:`(*, d)`, where *
represents any number of preceding dimensions (including none), and
:math:`d` corresponds to ``input_dimensions``.
:param int output_dimensions: The number of output components.
The expected tensor shape is :math:`(*, d)`, where *
represents any number of preceding dimensions (including none), and
:math:`d` corresponds to ``output_dimensions``.
:param int inner_size: The number of neurons for each hidden layer.
Default is ``20``.
:param int n_layers: The number of hidden layers. Default is ``2``.
:param func: The activation function. If a list is passed, it must have
the same length as ``n_layers``. If a single function is passed, it
is used for all layers, except for the last one.
Default is :class:`torch.nn.Tanh`.
:type func: torch.nn.Module | list[torch.nn.Module]
:param bool bias: If ``True``, bias terms are included in the network
layers. Default is ``True``.
:param transformer_nets: The two :class:`torch.nn.Module` acting as
transformer network. The input dimension of both networks must be
equal to ``input_dimensions``, and the output dimension must be
equal to ``inner_size``. If ``None``, two
:class:`~pina.model.block.residual.EnhancedLinear` layers are used.
Default is ``None``.
:type transformer_nets: list[torch.nn.Module] | tuple[torch.nn.Module]
:raises RuntimeError: If the number of layers and functions are
inconsistent.
"""
super().__init__()
# check type consistency
@@ -179,7 +197,7 @@ class ResidualFeedForward(torch.nn.Module):
self.functions = [func() for _ in range(len(self.layers))]
if len(self.layers) != len(self.functions):
raise RuntimeError("uncosistent number of layers and functions")
raise RuntimeError("Incosistent number of layers and functions")
unique_list = []
for layer, func_ in zip(self.layers, self.functions):
@@ -188,12 +206,12 @@ class ResidualFeedForward(torch.nn.Module):
def forward(self, x):
"""
Defines the computation performed at every call.
Forward pass for the :class:`ResidualFeedForward` model.
:param x: The tensor to apply the forward pass.
:type x: torch.Tensor
:return: the output computed by the model.
:rtype: torch.Tensor
:param x: The input tensor.
:type x: torch.Tensor | LabelTensor
:return: The output tensor.
:rtype: torch.Tensor | LabelTensor
"""
# enhance the input with transformer
input_ = []
@@ -210,6 +228,26 @@ class ResidualFeedForward(torch.nn.Module):
@staticmethod
def _check_transformer_nets(transformer_nets, input_dimensions, inner_size):
"""
Check the transformer networks consistency.
:param transformer_nets: The two :class:`torch.nn.Module` acting as
transformer network.
:type transformer_nets: list[torch.nn.Module] | tuple[torch.nn.Module]
:param int input_dimensions: The number of input components.
:param int inner_size: The number of neurons for each hidden layer.
:raises ValueError: If the passed ``transformer_nets`` is not a list of
length two.
:raises ValueError: If the passed ``transformer_nets`` is not a list of
:class:`torch.nn.Module`.
:raises ValueError: If the input dimension of the transformer network
is incompatible with the input dimension of the model.
:raises ValueError: If the output dimension of the transformer network
is incompatible with the inner size of the model.
:raises RuntimeError: If an unexpected error occurs.
:return: The two :class:`torch.nn.Module` acting as transformer network.
:rtype: list[torch.nn.Module] | tuple[torch.nn.Module]
"""
# check transformer nets
if transformer_nets is None:
transformer_nets = [
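Appended for reference, a minimal usage sketch for :class:`ResidualFeedForward` relying on the default transformer networks (dimensions are illustrative):

import torch
from pina.model import ResidualFeedForward

# With transformer_nets=None, two EnhancedLinear layers mapping
# input_dimensions to inner_size are built internally.
model = ResidualFeedForward(input_dimensions=3, output_dimensions=1,
                            inner_size=20, n_layers=2)
y = model(torch.rand(32, 3))  # shape: (32, 1)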

View File

@@ -1,5 +1,5 @@
"""
Fourier Neural Operator Module.
Module for the Fourier Neural Operator model class.
"""
import warnings
@@ -13,18 +13,16 @@ from .kernel_neural_operator import KernelNeuralOperator
class FourierIntegralKernel(torch.nn.Module):
"""
Implementation of Fourier Integral Kernel network.
Fourier Integral Kernel model class.
This class implements the Fourier Integral Kernel network, which is a
PINA implementation of Fourier Neural Operator kernel network.
It performs global convolution by operating in the Fourier space.
This class implements the Fourier Integral Kernel network, which
performs global convolution in the Fourier space.
.. seealso::
**Original reference**: Li, Z., Kovachki, N., Azizzadenesheli,
K., Liu, B., Bhattacharya, K., Stuart, A., & Anandkumar, A.
(2020). *Fourier neural operator for parametric partial
differential equations*.
**Original reference**: Li, Z., Kovachki, N., Azizzadenesheli, K., Liu,
B., Bhattacharya, K., Stuart, A., & Anandkumar, A. (2020).
*Fourier neural operator for parametric partial differential equations*.
DOI: `arXiv preprint arXiv:2010.08895.
<https://arxiv.org/abs/2010.08895>`_
"""
@@ -43,16 +41,31 @@ class FourierIntegralKernel(torch.nn.Module):
layers=None,
):
"""
:param int input_numb_fields: Number of input fields.
:param int output_numb_fields: Number of output fields.
:param int | list[int] n_modes: Number of modes.
:param int dimensions: Number of dimensions (1, 2, or 3).
:param int padding: Padding size, defaults to 8.
:param str padding_type: Type of padding, defaults to "constant".
:param int inner_size: Inner size, defaults to 20.
:param int n_layers: Number of layers, defaults to 2.
:param torch.nn.Module func: Activation function, defaults to nn.Tanh.
:param list[int] layers: List of layer sizes, defaults to None.
Initialization of the :class:`FourierIntegralKernel` class.
:param int input_numb_fields: The number of input fields.
:param int output_numb_fields: The number of output fields.
:param n_modes: The number of modes.
:type n_modes: int | list[int]
:param int dimensions: The number of dimensions. It can be set to ``1``,
``2``, or ``3``. Default is ``3``.
:param int padding: The padding size. Default is ``8``.
:param str padding_type: The padding strategy. Default is ``constant``.
:param int inner_size: The inner size. Default is ``20``.
:param int n_layers: The number of layers. Default is ``2``.
:param func: The activation function. If a list is passed, it must have
the same length as ``n_layers``. If a single function is passed, it
is used for all layers, except for the last one.
Default is :class:`torch.nn.Tanh`.
:type func: torch.nn.Module | list[torch.nn.Module]
:param list[int] layers: The list of the dimension of inner layers.
If ``None``, ``n_layers`` of dimension ``inner_size`` are used.
Otherwise, it overrides the values passed to ``n_layers`` and
``inner_size``. Default is ``None``.
:raises RuntimeError: If the number of layers and functions are
inconsistent.
:raises RuntimeError: If the number of layers and modes are
inconsistent.
"""
super().__init__()
@@ -84,7 +97,7 @@ class FourierIntegralKernel(torch.nn.Module):
if isinstance(func, list):
if len(layers) != len(func):
raise RuntimeError(
"Uncosistent number of layers and functions."
"Inconsistent number of layers and functions."
)
_functions = func
else:
@@ -97,7 +110,7 @@ class FourierIntegralKernel(torch.nn.Module):
n_modes
):
raise RuntimeError(
"Uncosistent number of layers and functions."
"Inconsistent number of layers and modes."
)
if all(isinstance(i, int) for i in n_modes):
n_modes = [n_modes] * len(layers)
@@ -129,19 +142,17 @@ class FourierIntegralKernel(torch.nn.Module):
def forward(self, x):
"""
Forward computation for Fourier Neural Operator. It performs a
lifting of the input by the ``lifting_net``. Then different layers
of Fourier Blocks are applied. Finally the output is projected
to the final dimensionality by the ``projecting_net``.
:param torch.Tensor x: The input tensor for fourier block,
depending on ``dimension`` in the initialization.
In particular it is expected:
Forward pass for the :class:`FourierIntegralKernel` model.
:param x: The input tensor for performing the computation. Depending
on the ``dimensions`` in the initialization, it expects a tensor
with the following shapes:
* 1D tensors: ``[batch, X, channels]``
* 2D tensors: ``[batch, X, Y, channels]``
* 3D tensors: ``[batch, X, Y, Z, channels]``
:return: The output tensor obtained from the kernels convolution.
:type x: torch.Tensor | LabelTensor
:raises Warning: If a LabelTensor is passed as input.
:return: The output tensor.
:rtype: torch.Tensor
"""
if isinstance(x, LabelTensor):
@@ -181,6 +192,22 @@ class FourierIntegralKernel(torch.nn.Module):
layers,
n_modes,
):
"""
Check the consistency of the input parameters.
:param int dimensions: The number of dimensions.
:param int padding: The padding size.
:param str padding_type: The padding strategy.
:param int inner_size: The inner size.
:param int n_layers: The number of layers.
:param func: The activation function.
:type func: torch.nn.Module | list[torch.nn.Module]
:param list[int] layers: The list of the dimension of inner layers.
:param n_modes: The number of modes.
:type n_modes: int | list[int]
:raises ValueError: If the input is not consistent.
"""
check_consistency(dimensions, int)
check_consistency(padding, int)
check_consistency(padding_type, str)
@@ -201,6 +228,15 @@ class FourierIntegralKernel(torch.nn.Module):
@staticmethod
def _get_fourier_block(dimensions):
"""
Retrieve the Fourier Block class based on the number of dimensions.
:param int dimensions: The number of dimensions.
:raises NotImplementedError: If the number of dimensions is not 1, 2,
or 3.
:return: The Fourier Block class.
:rtype: FourierBlock1D | FourierBlock2D | FourierBlock3D
"""
if dimensions == 1:
return FourierBlock1D
if dimensions == 2:
@@ -212,20 +248,18 @@ class FourierIntegralKernel(torch.nn.Module):
class FNO(KernelNeuralOperator):
"""
The PINA implementation of Fourier Neural Operator network.
Fourier Neural Operator model class.
Fourier Neural Operator (FNO) is a general architecture for
learning Operators. Unlike traditional machine learning methods
FNO is designed to map entire functions to other functions. It
can be trained with Supervised learning strategies. FNO does global
convolution by performing the operation on the Fourier space.
The Fourier Neural Operator (FNO) is a general architecture for learning
operators, which map functions to functions. It can be trained both with
Supervised and Physics-Informed learning strategies. The Fourier Neural
Operator performs global convolution in the Fourier space.
.. seealso::
**Original reference**: Li, Z., Kovachki, N., Azizzadenesheli,
K., Liu, B., Bhattacharya, K., Stuart, A., & Anandkumar, A.
(2020). *Fourier neural operator for parametric partial
differential equations*.
**Original reference**: Li, Z., Kovachki, N., Azizzadenesheli, K.,
Liu, B., Bhattacharya, K., Stuart, A., & Anandkumar, A. (2020).
*Fourier neural operator for parametric partial differential equations*.
DOI: `arXiv preprint arXiv:2010.08895.
<https://arxiv.org/abs/2010.08895>`_
"""
@@ -244,18 +278,27 @@ class FNO(KernelNeuralOperator):
layers=None,
):
"""
:param torch.nn.Module lifting_net: The neural network for lifting
the input.
:param torch.nn.Module projecting_net: The neural network for
projecting the output.
:param int | list[int] n_modes: Number of modes.
:param int dimensions: Number of dimensions (1, 2, or 3).
:param int padding: Padding size, defaults to 8.
:param str padding_type: Type of padding, defaults to `constant`.
:param int inner_size: Inner size, defaults to 20.
:param int n_layers: Number of layers, defaults to 2.
:param torch.nn.Module func: Activation function, defaults to nn.Tanh.
:param list[int] layers: List of layer sizes, defaults to None.
:param torch.nn.Module lifting_net: The lifting neural network mapping
the input to its hidden dimension.
:param torch.nn.Module projecting_net: The projection neural network
mapping the hidden representation to the output function.
:param n_modes: The number of modes.
:type n_modes: int | list[int]
:param int dimensions: The number of dimensions. It can be set to ``1``,
``2``, or ``3``. Default is ``3``.
:param int padding: The padding size. Default is ``8``.
:param str padding_type: The padding strategy. Default is ``constant``.
:param int inner_size: The inner size. Default is ``20``.
:param int n_layers: The number of layers. Default is ``2``.
:param func: The activation function. If a list is passed, it must have
the same length as ``n_layers``. If a single function is passed, it
is used for all layers, except for the last one.
Default is :class:`torch.nn.Tanh`.
:type func: torch.nn.Module | list[torch.nn.Module]
:param list[int] layers: The list of the dimension of inner layers.
If ``None``, ``n_layers`` of dimension ``inner_size`` are used.
Otherwise, it overrides the values passed to ``n_layers`` and
``inner_size``. Default is ``None``.
"""
lifting_operator_out = lifting_net(
torch.rand(size=next(lifting_net.parameters()).size())
@@ -279,19 +322,21 @@ class FNO(KernelNeuralOperator):
def forward(self, x):
"""
Forward computation for Fourier Neural Operator. It performs a
lifting of the input by the ``lifting_net``. Then different layers
of Fourier Blocks are applied. Finally the output is projected
to the final dimensionality by the ``projecting_net``.
Forward pass for the :class:`FNO` model.
:param torch.Tensor x: The input tensor for fourier block,
depending on ``dimension`` in the initialization. In
particular it is expected:
The ``lifting_net`` maps the input to the hidden dimension.
Then, several layers of Fourier blocks are applied. Finally, the
``projection_net`` maps the hidden representation to the output
function.
:param x: The input tensor for performing the computation. Depending
on the ``dimensions`` in the initialization, it expects a tensor
with the following shapes:
* 1D tensors: ``[batch, X, channels]``
* 2D tensors: ``[batch, X, Y, channels]``
* 3D tensors: ``[batch, X, Y, Z, channels]``
:return: The output tensor obtained from FNO.
:type x: torch.Tensor | LabelTensor
:return: The output tensor.
:rtype: torch.Tensor
"""

View File

@@ -1,5 +1,5 @@
"""
Module for the Graph Neural Operator and Graph Neural Kernel.
Module for the Graph Neural Operator model class.
"""
import torch
@@ -10,7 +10,18 @@ from .kernel_neural_operator import KernelNeuralOperator
class GraphNeuralKernel(torch.nn.Module):
"""
TODO add docstring
Graph Neural Kernel model class.
This class implements the Graph Neural Kernel network.
.. seealso::
**Original reference**: Li, Z., Kovachki, N., Azizzadenesheli, K.,
Liu, B., Bhattacharya, K., Stuart, A., Anandkumar, A. (2020).
*Neural Operator: Graph Kernel Network for Partial Differential
Equations*.
DOI: `arXiv preprint arXiv:2003.03485.
<https://arxiv.org/abs/2003.03485>`_
"""
def __init__(
@@ -26,28 +37,24 @@ class GraphNeuralKernel(torch.nn.Module):
shared_weights=False,
):
"""
The Graph Neural Kernel constructor.
Initialization of the :class:`GraphNeuralKernel` class.
:param width: The width of the kernel.
:type width: int
:param edge_features: The number of edge features.
:type edge_features: int
:param n_layers: The number of kernel layers.
:type n_layers: int
:param internal_n_layers: The number of layers the FF Neural Network
internal to each Kernel Layer.
:type internal_n_layers: int
:param internal_layers: Number of neurons of hidden layers(s) in the
FF Neural Network inside for each Kernel Layer.
:type internal_layers: list | tuple
:param internal_func: The activation function used inside the
computation of the representation of the edge features in the
Graph Integral Layer.
:param external_func: The activation function applied to the output of
the Graph Integral Layer.
:type external_func: torch.nn.Module
:param shared_weights: If ``True`` the weights of the Graph Integral
Layers are shared.
:param int width: The width of the kernel.
:param int edge_features: The number of edge features.
:param int n_layers: The number of kernel layers. Default is ``2``.
:param int internal_n_layers: The number of layers of the neural network
inside each kernel layer. Default is ``0``.
:param internal_layers: The number of neurons for each layer of the
neural network inside each kernel layer. Default is ``None``.
:type internal_layers: list[int] | tuple[int]
:param torch.nn.Module internal_func: The activation function used
inside each kernel layer. If ``None``, it uses the
:class:`torch.nn.Tanh` activation. Default is ``None``.
:param torch.nn.Module external_func: The activation function applied to
the output of each kernel layer. If ``None``, it uses the
:class:`torch.nn.Tanh` activation. Default is ``None``.
:param bool shared_weights: If ``True``, the weights of each kernel
layer are shared. Default is ``False``.
"""
super().__init__()
if external_func is None:
@@ -85,11 +92,33 @@ class GraphNeuralKernel(torch.nn.Module):
self._forward_func = self._forward_unshared
def _forward_unshared(self, x, edge_index, edge_attr):
"""
Forward pass for the Graph Neural Kernel with unshared weights.
:param x: The input tensor.
:type x: torch.Tensor | LabelTensor
:param torch.Tensor edge_index: The edge index.
:param edge_attr: The edge attributes.
:type edge_attr: torch.Tensor | LabelTensor
:return: The output tensor.
:rtype: torch.Tensor
"""
for layer in self.layers:
x = layer(x, edge_index, edge_attr)
return x
def _forward_shared(self, x, edge_index, edge_attr):
"""
Forward pass for the Graph Neural Kernel with shared weights.
:param x: The input tensor.
:type x: torch.Tensor | LabelTensor
:param torch.Tensor edge_index: The edge index.
:param edge_attr: The edge attributes.
:type edge_attr: torch.Tensor | LabelTensor
:return: The output tensor.
:rtype: torch.Tensor
"""
for _ in range(self.n_layers):
x = self.layers(x, edge_index, edge_attr)
return x
@@ -98,19 +127,34 @@ class GraphNeuralKernel(torch.nn.Module):
"""
The forward pass of the Graph Neural Kernel.
:param x: The input batch.
:type x: torch.Tensor
:param edge_index: The edge index.
:type edge_index: torch.Tensor
:param x: The input tensor.
:type x: torch.Tensor | LabelTensor
:param torch.Tensor edge_index: The edge index.
:param edge_attr: The edge attributes.
:type edge_attr: torch.Tensor
:type edge_attr: torch.Tensor | LabelTensor
:return: The output tensor.
:rtype: torch.Tensor
"""
return self._forward_func(x, edge_index, edge_attr)
class GraphNeuralOperator(KernelNeuralOperator):
"""
TODO add docstring
Graph Neural Operator model class.
The Graph Neural Operator is a general architecture for learning operators,
which map functions to functions. It can be trained both with Supervised
and Physics-Informed learning strategies. The Graph Neural Operator performs
graph convolution by means of a Graph Neural Kernel.
.. seealso::
**Original reference**: Li, Z., Kovachki, N., Azizzadenesheli, K.,
Liu, B., Bhattacharya, K., Stuart, A., Anandkumar, A. (2020).
*Neural Operator: Graph Kernel Network for Partial Differential
Equations*.
DOI: `arXiv preprint arXiv:2003.03485.
<https://arxiv.org/abs/2003.03485>`_
"""
def __init__(
@@ -127,34 +171,29 @@ class GraphNeuralOperator(KernelNeuralOperator):
shared_weights=True,
):
"""
The Graph Neural Operator constructor.
Initialization of the :class:`GraphNeuralOperator` class.
:param lifting_operator: The lifting operator mapping the node features
to its hidden dimension.
:type lifting_operator: torch.nn.Module
:param projection_operator: The projection operator mapping the hidden
representation of the nodes features to the output function.
:type projection_operator: torch.nn.Module
:param edge_features: Number of edge features.
:type edge_features: int
:param n_layers: The number of kernel layers.
:type n_layers: int
:param internal_n_layers: The number of layers the Feed Forward Neural
Network internal to each Kernel Layer.
:type internal_n_layers: int
:param internal_layers: Number of neurons of hidden layers(s) in the
FF Neural Network inside for each Kernel Layer.
:type internal_layers: list | tuple
:param internal_func: The activation function used inside the
computation of the representation of the edge features in the
Graph Integral Layer.
:type internal_func: torch.nn.Module
:param external_func: The activation function applied to the output of
the Graph Integral Kernel.
:type external_func: torch.nn.Module
:param shared_weights: If ``True`` the weights of the Graph Integral
Layers are shared.
:type shared_weights: bool
:param torch.nn.Module lifting_operator: The lifting neural network
mapping the input to its hidden dimension.
:param torch.nn.Module projection_operator: The projection neural
network mapping the hidden representation to the output function.
:param int edge_features: The number of edge features.
:param int n_layers: The number of kernel layers. Default is ``10``.
:param int internal_n_layers: The number of layers of the neural network
inside each kernel layer. Default is ``0``.
:param int inner_size: The size of the hidden layers of the neural
network inside each kernel layer. Default is ``None``.
:param internal_layers: The number of neurons for each layer of the
neural network inside each kernel layer. Default is ``None``.
:type internal_layers: list[int] | tuple[int]
:param torch.nn.Module internal_func: The activation function used
inside each kernel layer. If ``None``, it uses the
:class:`torch.nn.Tanh` activation. Default is ``None``.
:param torch.nn.Module external_func: The activation function applied to
the output of each kernel layer. If ``None``, it uses the
:class:`torch.nn.Tanh` activation. Default is ``None``.
:param bool shared_weights: If ``True``, the weights of each kernel
layer are shared. Default is ``True``.
"""
if internal_func is None:
@@ -182,8 +221,9 @@ class GraphNeuralOperator(KernelNeuralOperator):
"""
The forward pass of the Graph Neural Operator.
:param x: The input batch.
:type x: torch_geometric.data.Batch
:param torch_geometric.data.Batch x: The input graph.
:return: The output tensor.
:rtype: torch.Tensor
"""
x, edge_index, edge_attr = x.x, x.edge_index, x.edge_attr
x = self.lifting_operator(x)
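For reference, a hedged usage sketch; the node/edge dimensions are illustrative, a single :class:`torch_geometric.data.Data` object stands in for a batch, and the hidden width is assumed to be inferred from the lifting operator:

import torch
from torch_geometric.data import Data
from pina.model import GraphNeuralOperator

# Assumed graph: 50 nodes, 3 node features, 200 edges, 2 edge features.
lifting = torch.nn.Linear(3, 16)
projection = torch.nn.Linear(16, 1)
model = GraphNeuralOperator(
    lifting_operator=lifting,
    projection_operator=projection,
    edge_features=2,
)
graph = Data(x=torch.rand(50, 3),
             edge_index=torch.randint(0, 50, (2, 200)),
             edge_attr=torch.rand(200, 2))
out = model(graph)  # one output value per node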

View File

@@ -1,5 +1,5 @@
"""
Kernel Neural Operator Module.
Module for the Kernel Neural Operator model class.
"""
import torch
@@ -8,13 +8,14 @@ from ..utils import check_consistency
class KernelNeuralOperator(torch.nn.Module):
r"""
Base class for composing Neural Operators with integral kernels.
Base class for Neural Operators with integral kernels.
This is a base class for composing neural operator with multiple
integral kernels. All neural operator models defined in PINA inherit
from this class. The structure is inspired by the work of Kovachki, N.
et al. see Figure 2 of the reference for extra details. The Neural
Operators inheriting from this class can be written as:
This class serves as a foundation for building Neural Operators that
incorporate multiple integral kernels. All Neural Operator models in
PINA inherit from this class. The design follows the framework proposed
by Kovachki et al., as illustrated in Figure 2 of their work.
Neural Operators derived from this class can be expressed as:
.. math::
G_\theta := P \circ K_m \circ \cdots \circ K_1 \circ L
@@ -40,15 +41,18 @@ class KernelNeuralOperator(torch.nn.Module):
**Original reference**: Kovachki, N., Li, Z., Liu, B.,
Azizzadenesheli, K., Bhattacharya, K., Stuart, A., & Anandkumar, A.
(2023). *Neural operator: Learning maps between function
spaces with applications to PDEs*. Journal of Machine Learning
Research, 24(89), 1-97.
(2023).
*Neural operator: Learning maps between function spaces with
applications to PDEs*.
Journal of Machine Learning Research, 24(89), 1-97.
"""
def __init__(self, lifting_operator, integral_kernels, projection_operator):
"""
:param torch.nn.Module lifting_operator: The lifting operator
mapping the input to its hidden dimension.
Initialization of the :class:`KernelNeuralOperator` class.
:param torch.nn.Module lifting_operator: The lifting operator mapping
the input to its hidden dimension.
:param torch.nn.Module integral_kernels: List of integral kernels
mapping each hidden representation to the next one.
:param torch.nn.Module projection_operator: The projection operator
@@ -64,16 +68,19 @@ class KernelNeuralOperator(torch.nn.Module):
@property
def lifting_operator(self):
"""
The lifting operator property.
The lifting operator module.
:return: The lifting operator module.
:rtype: torch.nn.Module
"""
return self._lifting_operator
@lifting_operator.setter
def lifting_operator(self, value):
"""
The lifting operator setter
Set the lifting operator module.
:param torch.nn.Module value: The lifting operator torch module.
:param torch.nn.Module value: The lifting operator module.
"""
check_consistency(value, torch.nn.Module)
self._lifting_operator = value
@@ -81,16 +88,19 @@ class KernelNeuralOperator(torch.nn.Module):
@property
def projection_operator(self):
"""
The projection operator property.
The projection operator module.
:return: The projection operator module.
:rtype: torch.nn.Module
"""
return self._projection_operator
@projection_operator.setter
def projection_operator(self, value):
"""
The projection operator setter
Set the projection operator module.
:param torch.nn.Module value: The projection operator torch module.
:param torch.nn.Module value: The projection operator module.
"""
check_consistency(value, torch.nn.Module)
self._projection_operator = value
@@ -98,37 +108,41 @@ class KernelNeuralOperator(torch.nn.Module):
@property
def integral_kernels(self):
"""
The integral kernels operator property.
The integral kernels operator module.
:return: The integral kernels operator module.
:rtype: torch.nn.Module
"""
return self._integral_kernels
@integral_kernels.setter
def integral_kernels(self, value):
"""
The integral kernels operator setter
Set the integral kernels operator module.
:param torch.nn.Module value: The integral kernels operator torch
module.
:param torch.nn.Module value: The integral kernels operator module.
"""
check_consistency(value, torch.nn.Module)
self._integral_kernels = value
def forward(self, x):
r"""
Forward computation for Base Neural Operator. It performs a
lifting of the input by the ``lifting_operator``.
Then different layers integral kernels are applied using
``integral_kernels``. Finally the output is projected
to the final dimensionality by the ``projection_operator``.
Forward pass for the :class:`KernelNeuralOperator` model.
:param torch.Tensor x: The input tensor for performing the
computation. It expects a tensor :math:`B \times N \times D`,
where :math:`B` is the batch_size, :math:`N` the number of points
in the mesh, :math:`D` the dimension of the problem. In particular
:math:`D` is the number of spatial/paramtric/temporal variables
plus the field variables. For example for 2D problems with 2
output\ variables :math:`D=4`.
:return: The output tensor obtained from the NO.
The ``lifting_operator`` maps the input to the hidden dimension.
The ``integral_kernels`` apply the integral kernels to the hidden
representation. The ``projection_operator`` maps the hidden
representation to the output function.
:param x: The input tensor for performing the computation. It expects
a tensor :math:`B \times N \times D`, where :math:`B` is the
batch_size, :math:`N` the number of points in the mesh, and
:math:`D` the dimension of the problem. In particular, :math:`D`
is the number of spatial, parametric, and/or temporal variables
plus the field variables. For instance, for 2D problems with 2
output variables, :math:`D=4`.
:type x: torch.Tensor | LabelTensor
:return: The output tensor.
:rtype: torch.Tensor
"""
x = self.lifting_operator(x)
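A minimal composition sketch mirroring :math:`G_\theta = P \circ K_m \circ \cdots \circ K_1 \circ L`; the plain linear modules are illustrative stand-ins for real lifting, kernel, and projection networks:

import torch
from pina.model import KernelNeuralOperator

lifting = torch.nn.Linear(3, 16)        # L: input -> hidden
kernels = torch.nn.Sequential(          # K_1, ..., K_m
    torch.nn.Linear(16, 16), torch.nn.Tanh(),
    torch.nn.Linear(16, 16), torch.nn.Tanh(),
)
projection = torch.nn.Linear(16, 1)     # P: hidden -> output
model = KernelNeuralOperator(lifting, kernels, projection)
out = model(torch.rand(8, 100, 3))  # B x N x D in, B x N x 1 out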

View File

@@ -1,4 +1,4 @@
"""Module LowRank Neural Operator."""
"""Module for the Low Rank Neural Operator model class."""
import torch
from torch import nn
@@ -11,23 +11,20 @@ from .block.low_rank_block import LowRankBlock
class LowRankNeuralOperator(KernelNeuralOperator):
"""
Implementation of LowRank Neural Operator.
Low Rank Neural Operator model class.
LowRank Neural Operator is a general architecture for
learning Operators. Unlike traditional machine learning methods
LowRankNeuralOperator is designed to map entire functions
to other functions. It can be trained with Supervised or PINN based
learning strategies.
LowRankNeuralOperator does convolution by performing a low rank
approximation, see :class:`~pina.model.block.lowrank_layer.LowRankBlock`.
The Low Rank Neural Operator is a general architecture for learning
operators, which map functions to functions. It can be trained both with
Supervised and Physics-Informed learning strategies. The Low Rank Neural
Operator performs convolution by means of a low rank approximation.
.. seealso::
**Original reference**: Kovachki, N., Li, Z., Liu, B.,
Azizzadenesheli, K., Bhattacharya, K., Stuart, A., & Anandkumar, A.
(2023). *Neural operator: Learning maps between function
spaces with applications to PDEs*. Journal of Machine Learning
Research, 24(89), 1-97.
**Original reference**: Kovachki, N., Li, Z., Liu, B., Azizzadenesheli,
K., Bhattacharya, K., Stuart, A., & Anandkumar, A. (2023).
*Neural operator: Learning maps between function spaces with
applications to PDEs*.
Journal of Machine Learning Research, 24(89), 1-97.
"""
def __init__(
@@ -44,32 +41,35 @@ class LowRankNeuralOperator(KernelNeuralOperator):
bias=True,
):
"""
:param torch.nn.Module lifting_net: The neural network for lifting
the input. It must take as input the input field and the coordinates
at which the input field is avaluated. The output of the lifting
net is chosen as embedding dimension of the problem
:param torch.nn.Module projecting_net: The neural network for
projecting the output. It must take as input the embedding dimension
(output of the ``lifting_net``) plus the dimension
of the coordinates.
:param list[str] field_indices: the label of the fields
in the input tensor.
:param list[str] coordinates_indices: the label of the
coordinates in the input tensor.
:param int n_kernel_layers: number of hidden kernel layers.
Default is 4.
:param int inner_size: Number of neurons in the hidden layer(s) for the
basis function network. Default is 20.
:param int n_layers: Number of hidden layers. for the
basis function network. Default is 2.
:param func: The activation function to use for the
basis function network. If a single
:class:`torch.nn.Module` is passed, this is used as
activation function after any layers, except the last one.
If a list of Modules is passed,
they are used as activation functions at any layers, in order.
:param bool bias: If ``True`` the MLP will consider some bias for the
basis function network.
Initialization of the :class:`LowRankNeuralOperator` class.
:param torch.nn.Module lifting_net: The lifting neural network mapping
the input to its hidden dimension. It must take as input the input
field and the coordinates at which the input field is evaluated.
:param torch.nn.Module projecting_net: The projection neural network
mapping the hidden representation to the output function. It must
take as input the embedding dimension plus the dimension of the
coordinates.
:param list[str] field_indices: The labels of the fields in the input
tensor.
:param list[str] coordinates_indices: The labels of the coordinates in
the input tensor.
:param int n_kernel_layers: The number of hidden kernel layers.
:param int rank: The rank of the low rank approximation.
:param int inner_size: The number of neurons for each hidden layer in
the basis function neural network. Default is ``20``.
:param int n_layers: The number of hidden layers in the basis function
neural network. Default is ``2``.
:param func: The activation function. If a list is passed, it must have
the same length as ``n_layers``. If a single function is passed, it
is used for all layers, except for the last one.
Default is :class:`torch.nn.Tanh`.
:param bool bias: If ``True``, a bias term is included in the basis
function neural network. Default is ``True``.
:raises ValueError: If the input dimension does not match with the
labels of the fields and coordinates.
:raises ValueError: If the input dimension of the projecting network
does not match with the hidden dimension of the lifting network.
"""
# check consistency
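A minimal construction sketch for the parameters documented above; the
import path, layer sizes, and labels are assumptions for a hypothetical
2D problem with a single field:

    import torch
    from torch import nn
    from pina.model import LowRankNeuralOperator  # import path assumed

    # Coordinates ("x", "y") plus one field ("u") give input dimension 3.
    embedding_dim = 16
    lifting = nn.Linear(3, embedding_dim)         # field + coords -> hidden
    projecting = nn.Linear(embedding_dim + 2, 1)  # hidden + coords -> field

    model = LowRankNeuralOperator(
        lifting_net=lifting,
        projecting_net=projecting,
        field_indices=["u"],
        coordinates_indices=["x", "y"],
        n_kernel_layers=4,
        rank=4,  # rank of the low rank approximation
    )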
@@ -122,19 +122,20 @@ class LowRankNeuralOperator(KernelNeuralOperator):
def forward(self, x):
r"""
Forward computation for LowRank Neural Operator. It performs a
lifting of the input by the ``lifting_net``. Then different layers
of LowRank Neural Operator Blocks are applied.
Finally the output is projected to the final dimensionality
by the ``projecting_net``.
Forward pass for the :class:`LowRankNeuralOperator` model.
:param torch.Tensor x: The input tensor for fourier block,
depending on ``dimension`` in the initialization. It expects
a tensor :math:`B \times N \times D`,
where :math:`B` is the batch_size, :math:`N` the number of points
in the mesh, :math:`D` the dimension of the problem, i.e. the sum
of ``len(coordinates_indices)+len(field_indices)``.
:return: The output tensor obtained from Average Neural Operator.
The ``lifting_net`` maps the input to the hidden dimension.
Then, several layers of
:class:`~pina.model.block.low_rank_block.LowRankBlock` are
applied. Finally, the ``projecting_net`` maps the hidden representation
to the output function.
:param LabelTensor x: The input tensor for performing the computation.
It expects a tensor :math:`B \times N \times D`, where :math:`B` is
the batch size, :math:`N` the number of points in the mesh,
:math:`D` the dimension of the problem, i.e. the sum
of ``len(coordinates_indices)`` and ``len(field_indices)``.
:return: The output tensor.
:rtype: torch.Tensor
"""
# extract points
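Continuing the sketch above, the forward pass then consumes a labelled
B x N x D tensor (the LabelTensor construction is an assumption):

    from pina import LabelTensor  # import path assumed

    # B = 8 samples, N = 100 mesh points, D = len(["x", "y"]) + len(["u"]) = 3.
    pts = LabelTensor(torch.rand(8, 100, 3), labels=["x", "y", "u"])
    out = model(pts)  # expected shape: (8, 100, 1)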

View File

@@ -1,4 +1,4 @@
"""Module for Multi FeedForward model"""
"""Module for the Multi Feed Forward model class"""
from abc import ABC, abstractmethod
import torch
@@ -7,16 +7,21 @@ from .feed_forward import FeedForward
class MultiFeedForward(torch.nn.Module, ABC):
"""
The PINA implementation of MultiFeedForward network.
Multi Feed Forward neural network model class.
This model allows to create a network with multiple FeedForward combined
together. The user has to define the `forward` method choosing how to
combine the different FeedForward networks.
:param dict ffn_dict: dictionary of FeedForward networks.
This model allows the user to combine multiple Feed Forward neural
networks into a single model. The user is required to define the
``forward`` method, choosing how to combine the networks.
"""
def __init__(self, ffn_dict):
"""
Initialization of the :class:`MultiFeedForward` class.
:param dict ffn_dict: A dictionary containing the Feed Forward neural
networks to be combined.
:raises TypeError: If the input is not a dictionary.
"""
super().__init__()
if not isinstance(ffn_dict, dict):
@@ -28,5 +33,8 @@ class MultiFeedForward(torch.nn.Module, ABC):
@abstractmethod
def forward(self, *args, **kwargs):
"""
TODO: Docstring
Forward pass for the :class:`MultiFeedForward` model.
The user is required to define this method to choose how to combine the
networks.
"""

View File

@@ -1,19 +1,26 @@
"""Module for Spline model"""
"""Module for the Spline model class"""
import torch
from ..utils import check_consistency
class Spline(torch.nn.Module):
"""TODO: Docstring for Spline."""
"""
Spline model class.
"""
def __init__(self, order=4, knots=None, control_points=None) -> None:
"""
Spline model.
Initialization of the :class:`Spline` class.
:param int order: the order of the spline.
:param torch.Tensor knots: the knot vector.
:param torch.Tensor control_points: the control points.
:param int order: The order of the spline. Default is ``4``.
:param torch.Tensor knots: The tensor representing knots. If ``None``,
the knots will be initialized automatically. Default is ``None``.
:param torch.Tensor control_points: The control points. Default is
``None``.
:raises ValueError: If the order is negative.
:raises ValueError: If both knots and control points are ``None``.
:raises ValueError: If the knot tensor is not one-dimensional.
"""
super().__init__()
@@ -63,13 +70,13 @@ class Spline(torch.nn.Module):
def basis(self, x, k, i, t):
"""
Recursive function to compute the basis functions of the spline.
Recursive method to compute the basis functions of the spline.
:param torch.Tensor x: points to be evaluated.
:param int k: spline degree
:param int i: the index of the interval
:param torch.Tensor t: vector of knots
:return: the basis functions evaluated at x
:param torch.Tensor x: The points to be evaluated.
:param int k: The spline degree.
:param int i: The index of the interval.
:param torch.Tensor t: The tensor of knots.
:return: The basis functions evaluated at ``x``.
:rtype: torch.Tensor
"""
@@ -100,11 +107,23 @@ class Spline(torch.nn.Module):
@property
def control_points(self):
"""TODO: Docstring for control_points."""
"""
The control points of the spline.
:return: The control points.
:rtype: torch.Tensor
"""
return self._control_points
@control_points.setter
def control_points(self, value):
"""
Set the control points of the spline.
:param value: The control points.
:type value: torch.Tensor | dict
:raises ValueError: If an invalid value is passed.
"""
if isinstance(value, dict):
if "n" not in value:
raise ValueError("Invalid value for control_points")
@@ -118,11 +137,23 @@ class Spline(torch.nn.Module):
@property
def knots(self):
"""TODO: Docstring for knots."""
"""
The knots of the spline.
:return: The knots.
:rtype: torch.Tensor
"""
return self._knots
@knots.setter
def knots(self, value):
"""
Set the knots of the spline.
:param value: The knots.
:type value: torch.Tensor | dict
:raises ValueError: If an invalid value is passed.
"""
if isinstance(value, dict):
type_ = value.get("type", "auto")
@@ -152,10 +183,10 @@ class Spline(torch.nn.Module):
def forward(self, x):
"""
Forward pass of the spline model.
Forward pass for the :class:`Spline` model.
:param torch.Tensor x: points to be evaluated.
:return: the spline evaluated at x
:param torch.Tensor x: The input tensor.
:return: The output tensor.
:rtype: torch.Tensor
"""
t = self.knots
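A minimal evaluation sketch; the import path and the knot/control-point
values are assumptions (for a clamped spline, the number of control points
equals ``len(knots) - order``):

    import torch
    from pina.model import Spline  # import path assumed

    # Clamped cubic spline (order 4) on [0, 3]: 10 knots -> 6 control points.
    knots = torch.tensor([0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0])
    control_points = torch.tensor([0.0, 1.0, -1.0, 0.5, 2.0, 0.0])

    spline = Spline(order=4, knots=knots, control_points=control_points)

    # Sample strictly inside the knot span to sidestep endpoint conventions
    # in the half-open basis intervals.
    x = torch.linspace(0.0, 2.9, 30)
    y = spline(x)  # the spline evaluated at the 30 sample points

The forward pass evaluates the linear combination of the basis functions,
with the control points as coefficients.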