fix doc model part 1

This commit is contained in:
giovanni
2025-03-14 12:24:27 +01:00
committed by Nicola Demo
parent def8f5a1d3
commit 8dc682c849
10 changed files with 676 additions and 433 deletions

View File

@@ -1,4 +1,4 @@
"""Module for DeepONet model"""
"""Module for the DeepONet and MIONet model classes"""
from functools import partial
import torch
@@ -8,22 +8,18 @@ from ..utils import check_consistency, is_function
class MIONet(torch.nn.Module):
"""
The PINA implementation of MIONet network.
MIONet model class.
MIONet is a general architecture for learning Operators defined
on the tensor product of Banach spaces. Unlike traditional machine
learning methods MIONet is designed to map entire functions to other
functions. It can be trained both with Physics Informed or Supervised
learning strategies.
The MIONet is a general architecture for learning operators, which map
functions to functions. It can be trained with both Supervised and
Physics-Informed learning strategies.
.. seealso::
**Original reference**: Jin, Pengzhan, Shuai Meng, and Lu Lu.
**Original reference**: Jin, P., Meng, S., and Lu L. (2022).
*MIONet: Learning multiple-input operators via tensor product.*
SIAM Journal on Scientific Computing 44.6 (2022): A3490-A351
DOI: `10.1137/22M1477751
<https://doi.org/10.1137/22M1477751>`_
DOI: `10.1137/22M1477751 <https://doi.org/10.1137/22M1477751>`_
"""
def __init__(
@@ -35,42 +31,44 @@ class MIONet(torch.nn.Module):
translation=True,
):
"""
:param dict networks: The neural networks to use as
models. The ``dict`` takes as key a neural network, and
as value the list of indeces to extract from the input variable
in the forward pass of the neural network. If a list of ``int``
is passed, the corresponding columns of the inner most entries are
extracted.
If a list of ``str`` is passed the variables of the corresponding
:py:obj:`pina.label_tensor.LabelTensor`are extracted. The
``torch.nn.Module`` model has to take as input a
:py:obj:`pina.label_tensor.LabelTensor` or :class:`torch.Tensor`.
Default implementation consist of different branch nets and one
Initialization of the :class:`MIONet` class.
:param dict networks: The neural networks to use as models. The ``dict``
takes as key a neural network, and as value the list of indeces to
extract from the input variable in the forward pass of the neural
network. If a ``list[int]`` is passed, the corresponding columns of
the inner most entries are extracted. If a ``list[str]`` is passed
the variables of the corresponding
:class:`~pina.label_tensor.LabelTensor` are extracted.
Each :class:`torch.nn.Module` model has to take as input either a
:class:`~pina.label_tensor.LabelTensor` or a :class:`torch.Tensor`.
Default implementation consists of several branch nets and one
trunk nets.
:param str or Callable aggregator: Aggregator to be used to aggregate
partial results from the modules in `nets`. Partial results are
aggregated component-wise. Available aggregators include
sum: ``+``, product: ``*``, mean: ``mean``, min: ``min``, max:
``max``.
:param str or Callable reduction: Reduction to be used to reduce
the aggregated result of the modules in `nets` to the desired output
dimension. Available reductions include
sum: ``+``, product: ``*``, mean: ``mean``, min: ``min``, max:
``max``.
:param bool or Callable scale: Scaling the final output before returning
the forward pass, default ``True``.
:param bool or Callable translation: Translating the final output before
returning the forward pass, default ``True``.
:param aggregator: The aggregator to be used to aggregate component-wise
partial results from the modules in ``networks``. Available
aggregators include: sum: ``+``, product: ``*``, mean: ``mean``,
min: ``min``, max: ``max``. Default is ``*``.
:type aggregator: str or Callable
:param reduction: The reduction to be used to reduce the aggregated
result of the modules in ``networks`` to the desired output
dimension. Available reductions include: sum: ``+``, product: ``*``,
mean: ``mean``, min: ``min``, max: ``max``. Default is ``+``.
:type reduction: str or Callable
:param bool scale: If ``True``, the final output is scaled before being
returned in the forward pass. Default is ``True``.
:param bool translation: If ``True``, the final output is translated
before being returned in the forward pass. Default is ``True``.
:raises ValueError: If the passed networks have not the same output
dimension.
.. warning::
In the forward pass we do not check if the input is instance of
:py:obj:`pina.label_tensor.LabelTensor` or :class:`torch.Tensor`.
A general rule is that for a :py:obj:`pina.label_tensor.LabelTensor`
input both list of integers and list of strings can be passed for
``input_indeces_branch_net``and ``input_indeces_trunk_net``.
Differently, for a :class:`torch.Tensor` only a list of integers
can be passed for ``input_indeces_branch_net``and
``input_indeces_trunk_net``.
No checks are performed in the forward pass to verify if the input
is instance of either :class:`~pina.label_tensor.LabelTensor` or
:class:`torch.Tensor`. In general, in case of a
:class:`~pina.label_tensor.LabelTensor`, both a ``list[int]`` or a
``list[str]`` can be passed as ``networks`` dict values.
Differently, in case of a :class:`torch.Tensor`, only a
``list[int]`` can be passed as ``networks`` dict values.
:Example:
>>> branch_net1 = FeedForward(input_dimensons=1,
@@ -162,6 +160,10 @@ class MIONet(torch.nn.Module):
"""
Return a dictionary of functions that can be used as aggregators or
reductions.
:param dict kwargs: Additional parameters.
:return: A dictionary of functions.
:rtype: dict
"""
return {
"+": partial(torch.sum, **kwargs),
@@ -172,6 +174,13 @@ class MIONet(torch.nn.Module):
}
def _init_aggregator(self, aggregator):
"""
Initialize the aggregator.
:param aggregator: The aggregator to be used to aggregate.
:type aggregator: str or Callable
:raises ValueError: If the aggregator is not supported.
"""
aggregator_funcs = self._symbol_functions(dim=2)
if aggregator in aggregator_funcs:
aggregator_func = aggregator_funcs[aggregator]
@@ -184,6 +193,13 @@ class MIONet(torch.nn.Module):
self._aggregator_type = aggregator
def _init_reduction(self, reduction):
"""
Initialize the reduction.
:param reduction: The reduction to be used.
:type reduction: str or Callable
:raises ValueError: If the reduction is not supported.
"""
reduction_funcs = self._symbol_functions(dim=-1)
if reduction in reduction_funcs:
reduction_func = reduction_funcs[reduction]
@@ -196,6 +212,18 @@ class MIONet(torch.nn.Module):
self._reduction_type = reduction
def _get_vars(self, x, indeces):
"""
Extract the variables from the input tensor.
:param x: The input tensor.
:type x: LabelTensor | torch.Tensor
:param indeces: The indeces to extract.
:type indeces: list[int] | list[str]
:raises RuntimeError: If failing to extract the variables.
:raises RuntimeError: If failing to extract the right indeces.
:return: The extracted variables.
:rtype: LabelTensor | torch.Tensor
"""
if isinstance(indeces[0], str):
try:
return x.extract(indeces)
@@ -216,12 +244,12 @@ class MIONet(torch.nn.Module):
def forward(self, x):
"""
Defines the computation performed at every call.
Forward pass for the :class:`MIONet` model.
:param LabelTensor or torch.Tensor x: The input tensor for the forward
call.
:return: The output computed by the DeepONet model.
:rtype: LabelTensor or torch.Tensor
:param x: The input tensor.
:type x: LabelTensor | torch.Tensor
:return: The output tensor.
:rtype: LabelTensor | torch.Tensor
"""
# forward pass
@@ -248,13 +276,19 @@ class MIONet(torch.nn.Module):
def aggregator(self):
"""
The aggregator function.
:return: The aggregator function.
:rtype: str or Callable
"""
return self._aggregator
@property
def reduction(self):
"""
The translation factor.
The reduction function.
:return: The reduction function.
:rtype: str or Callable
"""
return self._reduction
@@ -262,13 +296,19 @@ class MIONet(torch.nn.Module):
def scale(self):
"""
The scale factor.
:return: The scale factor.
:rtype: torch.Tensor
"""
return self._scale
@property
def translation(self):
"""
The translation factor for MIONet.
The translation factor.
:return: The translation factor.
:rtype: torch.Tensor
"""
return self._trasl
@@ -276,6 +316,9 @@ class MIONet(torch.nn.Module):
def indeces_variables_extracted(self):
"""
The input indeces for each model in form of list.
:return: The indeces for each model.
:rtype: list
"""
return self._indeces
@@ -283,24 +326,27 @@ class MIONet(torch.nn.Module):
def model(self):
"""
The models in form of list.
:return: The models.
:rtype: list[torch.nn.Module]
"""
return self._indeces
class DeepONet(MIONet):
"""
The PINA implementation of DeepONet network.
DeepONet model class.
DeepONet is a general architecture for learning Operators. Unlike
traditional machine learning methods DeepONet is designed to map
entire functions to other functions. It can be trained both with
Physics Informed or Supervised learning strategies.
The MIONet is a general architecture for learning operators, which map
functions to functions. It can be trained with both Supervised and
Physics-Informed learning strategies.
.. seealso::
**Original reference**: Lu, L., Jin, P., Pang, G. et al. *Learning
nonlinear operators via DeepONet based on the universal approximation
theorem of operator*. Nat Mach Intell 3, 218229 (2021).
**Original reference**: Lu, L., Jin, P., Pang, G. et al.
*Learning nonlinear operators via DeepONet based on the universal
approximation theorem of operator*.
Nat Mach Intell 3, 218-229 (2021).
DOI: `10.1038/s42256-021-00302-5
<https://doi.org/10.1038/s42256-021-00302-5>`_
@@ -318,42 +364,44 @@ class DeepONet(MIONet):
translation=True,
):
"""
Initialization of the :class:`DeepONet` class.
:param torch.nn.Module branch_net: The neural network to use as branch
model. It has to take as input a
:py:obj:`pina.label_tensor.LabelTensor` or :class:`torch.Tensor`.
The number of dimensions of the output has to be the same of the
``trunk_net``.
model. It has to take as input either a
:class:`~pina.label_tensor.LabelTensor` or a :class:`torch.Tensor`.
The output dimension has to be the same as that of ``trunk_net``.
:param torch.nn.Module trunk_net: The neural network to use as trunk
model. It has to take as input a
:py:obj:`pina.label_tensor.LabelTensor` or :class:`torch.Tensor`.
The number of dimensions of the output has to be the same of the
``branch_net``.
:param list(int) or list(str) input_indeces_branch_net: List of indeces
to extract from the input variable in the forward pass for the
branch net. If a list of ``int`` is passed, the corresponding
columns of the inner most entries are extracted. If a list of
``str`` is passed the variables of the corresponding
:py:obj:`pina.label_tensor.LabelTensor` are extracted.
:param list(int) or list(str) input_indeces_trunk_net: List of indeces
to extract from the input variable in the forward pass for the
trunk net. If a list of ``int`` is passed, the corresponding columns
of the inner most entries are extracted. If a list of ``str`` is
passed the variables of the corresponding
:py:obj:`pina.label_tensor.LabelTensor` are extracted.
:param str or Callable aggregator: Aggregator to be used to aggregate
partial results from the modules in `nets`. Partial results are
aggregated component-wise. Available aggregators include
sum: ``+``, product: ``*``, mean: ``mean``, min: ``min``,
max: ``max``.
:param str or Callable reduction: Reduction to be used to reduce
the aggregated result of the modules in `nets` to the desired output
dimension. Available reductions include
sum: ``+``, product: ``*``, mean: ``mean``, min: ``min``,
max: ``max``.
:param bool or Callable scale: Scaling the final output before returning
the forward pass, default True.
:param bool or Callable translation: Translating the final output before
returning the forward pass, default True.
model. It has to take as input either a
:class:`~pina.label_tensor.LabelTensor` or a :class:`torch.Tensor`.
The output dimension has to be the same as that of ``branch_net``.
:param input_indeces_branch_net: List of indeces to extract from the
input variable of the ``branch_net``.
If a list of ``int`` is passed, the corresponding columns of the
inner most entries are extracted. If a list of ``str`` is passed the
variables of the corresponding
:class:`~pina.label_tensor.LabelTensor` are extracted.
:type input_indeces_branch_net: list[int] | list[str]
:param input_indeces_trunk_net: List of indeces to extract from the
input variable of the ``trunk_net``.
If a list of ``int`` is passed, the corresponding columns of the
inner most entries are extracted. If a list of ``str`` is passed the
variables of the corresponding
:class:`~pina.label_tensor.LabelTensor` are extracted.
:type input_indeces_trunk_net: list[int] | list[str]
:param aggregator: The aggregator to be used to aggregate component-wise
partial results from the modules in ``networks``. Available
aggregators include: sum: ``+``, product: ``*``, mean: ``mean``,
min: ``min``, max: ``max``. Default is ``*``.
:type aggregator: str or Callable
:param reduction: The reduction to be used to reduce the aggregated
result of the modules in ``networks`` to the desired output
dimension. Available reductions include: sum: ``+``, product: ``*``,
mean: ``mean``, min: ``min``, max: ``max``. Default is ``+``.
:type reduction: str or Callable
:param bool scale: If ``True``, the final output is scaled before being
returned in the forward pass. Default is ``True``.
:param bool translation: If ``True``, the final output is translated
before being returned in the forward pass. Default is ``True``.
.. warning::
In the forward pass we do not check if the input is instance of
@@ -364,6 +412,14 @@ class DeepONet(MIONet):
Differently, for a :class:`torch.Tensor` only a list of integers can
be passed for ``input_indeces_branch_net`` and
``input_indeces_trunk_net``.
.. warning::
No checks are performed in the forward pass to verify if the input
is instance of either :class:`~pina.label_tensor.LabelTensor` or
:class:`torch.Tensor`. In general, in case of a
:class:`~pina.label_tensor.LabelTensor`, both a ``list[int]`` or a
``list[str]`` can be passed as ``input_indeces_branch_net`` and
``input_indeces_trunk_net``. Differently, in case of a
:class:`torch.Tensor`, only a ``list[int]`` can be passed.
:Example:
>>> branch_net = FeedForward(input_dimensons=1,
@@ -411,25 +467,31 @@ class DeepONet(MIONet):
def forward(self, x):
"""
Defines the computation performed at every call.
Forward pass for the :class:`DeepONet` model.
:param LabelTensor or torch.Tensor x: The input tensor for the forward
call.
:return: The output computed by the DeepONet model.
:rtype: LabelTensor or torch.Tensor
:param x: The input tensor.
:type x: LabelTensor | torch.Tensor
:return: The output tensor.
:rtype: LabelTensor | torch.Tensor
"""
return super().forward(x)
@property
def branch_net(self):
"""
The branch net for DeepONet.
The branch net of the DeepONet.
:return: The branch net.
:rtype: torch.nn.Module
"""
return self.models[0]
@property
def trunk_net(self):
"""
The trunk net for DeepONet.
The trunk net of the DeepONet.
:return: The trunk net.
:rtype: torch.nn.Module
"""
return self.models[1]