fix doc model part 1
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
"""Module for DeepONet model"""
|
||||
"""Module for the DeepONet and MIONet model classes"""
|
||||
|
||||
from functools import partial
|
||||
import torch
|
||||
@@ -8,22 +8,18 @@ from ..utils import check_consistency, is_function
|
||||
|
||||
class MIONet(torch.nn.Module):
|
||||
"""
|
||||
The PINA implementation of MIONet network.
|
||||
MIONet model class.
|
||||
|
||||
MIONet is a general architecture for learning Operators defined
|
||||
on the tensor product of Banach spaces. Unlike traditional machine
|
||||
learning methods MIONet is designed to map entire functions to other
|
||||
functions. It can be trained both with Physics Informed or Supervised
|
||||
learning strategies.
|
||||
The MIONet is a general architecture for learning operators, which map
|
||||
functions to functions. It can be trained with both Supervised and
|
||||
Physics-Informed learning strategies.
|
||||
|
||||
.. seealso::
|
||||
|
||||
**Original reference**: Jin, Pengzhan, Shuai Meng, and Lu Lu.
|
||||
**Original reference**: Jin, P., Meng, S., and Lu L. (2022).
|
||||
*MIONet: Learning multiple-input operators via tensor product.*
|
||||
SIAM Journal on Scientific Computing 44.6 (2022): A3490-A351
|
||||
DOI: `10.1137/22M1477751
|
||||
<https://doi.org/10.1137/22M1477751>`_
|
||||
|
||||
DOI: `10.1137/22M1477751 <https://doi.org/10.1137/22M1477751>`_
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -35,42 +31,44 @@ class MIONet(torch.nn.Module):
|
||||
translation=True,
|
||||
):
|
||||
"""
|
||||
:param dict networks: The neural networks to use as
|
||||
models. The ``dict`` takes as key a neural network, and
|
||||
as value the list of indeces to extract from the input variable
|
||||
in the forward pass of the neural network. If a list of ``int``
|
||||
is passed, the corresponding columns of the inner most entries are
|
||||
extracted.
|
||||
If a list of ``str`` is passed the variables of the corresponding
|
||||
:py:obj:`pina.label_tensor.LabelTensor`are extracted. The
|
||||
``torch.nn.Module`` model has to take as input a
|
||||
:py:obj:`pina.label_tensor.LabelTensor` or :class:`torch.Tensor`.
|
||||
Default implementation consist of different branch nets and one
|
||||
Initialization of the :class:`MIONet` class.
|
||||
|
||||
:param dict networks: The neural networks to use as models. The ``dict``
|
||||
takes as key a neural network, and as value the list of indeces to
|
||||
extract from the input variable in the forward pass of the neural
|
||||
network. If a ``list[int]`` is passed, the corresponding columns of
|
||||
the inner most entries are extracted. If a ``list[str]`` is passed
|
||||
the variables of the corresponding
|
||||
:class:`~pina.label_tensor.LabelTensor` are extracted.
|
||||
Each :class:`torch.nn.Module` model has to take as input either a
|
||||
:class:`~pina.label_tensor.LabelTensor` or a :class:`torch.Tensor`.
|
||||
Default implementation consists of several branch nets and one
|
||||
trunk nets.
|
||||
:param str or Callable aggregator: Aggregator to be used to aggregate
|
||||
partial results from the modules in `nets`. Partial results are
|
||||
aggregated component-wise. Available aggregators include
|
||||
sum: ``+``, product: ``*``, mean: ``mean``, min: ``min``, max:
|
||||
``max``.
|
||||
:param str or Callable reduction: Reduction to be used to reduce
|
||||
the aggregated result of the modules in `nets` to the desired output
|
||||
dimension. Available reductions include
|
||||
sum: ``+``, product: ``*``, mean: ``mean``, min: ``min``, max:
|
||||
``max``.
|
||||
:param bool or Callable scale: Scaling the final output before returning
|
||||
the forward pass, default ``True``.
|
||||
:param bool or Callable translation: Translating the final output before
|
||||
returning the forward pass, default ``True``.
|
||||
:param aggregator: The aggregator to be used to aggregate component-wise
|
||||
partial results from the modules in ``networks``. Available
|
||||
aggregators include: sum: ``+``, product: ``*``, mean: ``mean``,
|
||||
min: ``min``, max: ``max``. Default is ``*``.
|
||||
:type aggregator: str or Callable
|
||||
:param reduction: The reduction to be used to reduce the aggregated
|
||||
result of the modules in ``networks`` to the desired output
|
||||
dimension. Available reductions include: sum: ``+``, product: ``*``,
|
||||
mean: ``mean``, min: ``min``, max: ``max``. Default is ``+``.
|
||||
:type reduction: str or Callable
|
||||
:param bool scale: If ``True``, the final output is scaled before being
|
||||
returned in the forward pass. Default is ``True``.
|
||||
:param bool translation: If ``True``, the final output is translated
|
||||
before being returned in the forward pass. Default is ``True``.
|
||||
:raises ValueError: If the passed networks have not the same output
|
||||
dimension.
|
||||
|
||||
.. warning::
|
||||
In the forward pass we do not check if the input is instance of
|
||||
:py:obj:`pina.label_tensor.LabelTensor` or :class:`torch.Tensor`.
|
||||
A general rule is that for a :py:obj:`pina.label_tensor.LabelTensor`
|
||||
input both list of integers and list of strings can be passed for
|
||||
``input_indeces_branch_net``and ``input_indeces_trunk_net``.
|
||||
Differently, for a :class:`torch.Tensor` only a list of integers
|
||||
can be passed for ``input_indeces_branch_net``and
|
||||
``input_indeces_trunk_net``.
|
||||
No checks are performed in the forward pass to verify if the input
|
||||
is instance of either :class:`~pina.label_tensor.LabelTensor` or
|
||||
:class:`torch.Tensor`. In general, in case of a
|
||||
:class:`~pina.label_tensor.LabelTensor`, both a ``list[int]`` or a
|
||||
``list[str]`` can be passed as ``networks`` dict values.
|
||||
Differently, in case of a :class:`torch.Tensor`, only a
|
||||
``list[int]`` can be passed as ``networks`` dict values.
|
||||
|
||||
:Example:
|
||||
>>> branch_net1 = FeedForward(input_dimensons=1,
|
||||
@@ -162,6 +160,10 @@ class MIONet(torch.nn.Module):
|
||||
"""
|
||||
Return a dictionary of functions that can be used as aggregators or
|
||||
reductions.
|
||||
|
||||
:param dict kwargs: Additional parameters.
|
||||
:return: A dictionary of functions.
|
||||
:rtype: dict
|
||||
"""
|
||||
return {
|
||||
"+": partial(torch.sum, **kwargs),
|
||||
@@ -172,6 +174,13 @@ class MIONet(torch.nn.Module):
|
||||
}
|
||||
|
||||
def _init_aggregator(self, aggregator):
|
||||
"""
|
||||
Initialize the aggregator.
|
||||
|
||||
:param aggregator: The aggregator to be used to aggregate.
|
||||
:type aggregator: str or Callable
|
||||
:raises ValueError: If the aggregator is not supported.
|
||||
"""
|
||||
aggregator_funcs = self._symbol_functions(dim=2)
|
||||
if aggregator in aggregator_funcs:
|
||||
aggregator_func = aggregator_funcs[aggregator]
|
||||
@@ -184,6 +193,13 @@ class MIONet(torch.nn.Module):
|
||||
self._aggregator_type = aggregator
|
||||
|
||||
def _init_reduction(self, reduction):
|
||||
"""
|
||||
Initialize the reduction.
|
||||
|
||||
:param reduction: The reduction to be used.
|
||||
:type reduction: str or Callable
|
||||
:raises ValueError: If the reduction is not supported.
|
||||
"""
|
||||
reduction_funcs = self._symbol_functions(dim=-1)
|
||||
if reduction in reduction_funcs:
|
||||
reduction_func = reduction_funcs[reduction]
|
||||
@@ -196,6 +212,18 @@ class MIONet(torch.nn.Module):
|
||||
self._reduction_type = reduction
|
||||
|
||||
def _get_vars(self, x, indeces):
|
||||
"""
|
||||
Extract the variables from the input tensor.
|
||||
|
||||
:param x: The input tensor.
|
||||
:type x: LabelTensor | torch.Tensor
|
||||
:param indeces: The indeces to extract.
|
||||
:type indeces: list[int] | list[str]
|
||||
:raises RuntimeError: If failing to extract the variables.
|
||||
:raises RuntimeError: If failing to extract the right indeces.
|
||||
:return: The extracted variables.
|
||||
:rtype: LabelTensor | torch.Tensor
|
||||
"""
|
||||
if isinstance(indeces[0], str):
|
||||
try:
|
||||
return x.extract(indeces)
|
||||
@@ -216,12 +244,12 @@ class MIONet(torch.nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Defines the computation performed at every call.
|
||||
Forward pass for the :class:`MIONet` model.
|
||||
|
||||
:param LabelTensor or torch.Tensor x: The input tensor for the forward
|
||||
call.
|
||||
:return: The output computed by the DeepONet model.
|
||||
:rtype: LabelTensor or torch.Tensor
|
||||
:param x: The input tensor.
|
||||
:type x: LabelTensor | torch.Tensor
|
||||
:return: The output tensor.
|
||||
:rtype: LabelTensor | torch.Tensor
|
||||
"""
|
||||
|
||||
# forward pass
|
||||
@@ -248,13 +276,19 @@ class MIONet(torch.nn.Module):
|
||||
def aggregator(self):
|
||||
"""
|
||||
The aggregator function.
|
||||
|
||||
:return: The aggregator function.
|
||||
:rtype: str or Callable
|
||||
"""
|
||||
return self._aggregator
|
||||
|
||||
@property
|
||||
def reduction(self):
|
||||
"""
|
||||
The translation factor.
|
||||
The reduction function.
|
||||
|
||||
:return: The reduction function.
|
||||
:rtype: str or Callable
|
||||
"""
|
||||
return self._reduction
|
||||
|
||||
@@ -262,13 +296,19 @@ class MIONet(torch.nn.Module):
|
||||
def scale(self):
|
||||
"""
|
||||
The scale factor.
|
||||
|
||||
:return: The scale factor.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return self._scale
|
||||
|
||||
@property
|
||||
def translation(self):
|
||||
"""
|
||||
The translation factor for MIONet.
|
||||
The translation factor.
|
||||
|
||||
:return: The translation factor.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return self._trasl
|
||||
|
||||
@@ -276,6 +316,9 @@ class MIONet(torch.nn.Module):
|
||||
def indeces_variables_extracted(self):
|
||||
"""
|
||||
The input indeces for each model in form of list.
|
||||
|
||||
:return: The indeces for each model.
|
||||
:rtype: list
|
||||
"""
|
||||
return self._indeces
|
||||
|
||||
@@ -283,24 +326,27 @@ class MIONet(torch.nn.Module):
|
||||
def model(self):
|
||||
"""
|
||||
The models in form of list.
|
||||
|
||||
:return: The models.
|
||||
:rtype: list[torch.nn.Module]
|
||||
"""
|
||||
return self._indeces
|
||||
|
||||
|
||||
class DeepONet(MIONet):
|
||||
"""
|
||||
The PINA implementation of DeepONet network.
|
||||
DeepONet model class.
|
||||
|
||||
DeepONet is a general architecture for learning Operators. Unlike
|
||||
traditional machine learning methods DeepONet is designed to map
|
||||
entire functions to other functions. It can be trained both with
|
||||
Physics Informed or Supervised learning strategies.
|
||||
The MIONet is a general architecture for learning operators, which map
|
||||
functions to functions. It can be trained with both Supervised and
|
||||
Physics-Informed learning strategies.
|
||||
|
||||
.. seealso::
|
||||
|
||||
**Original reference**: Lu, L., Jin, P., Pang, G. et al. *Learning
|
||||
nonlinear operators via DeepONet based on the universal approximation
|
||||
theorem of operator*. Nat Mach Intell 3, 218–229 (2021).
|
||||
**Original reference**: Lu, L., Jin, P., Pang, G. et al.
|
||||
*Learning nonlinear operators via DeepONet based on the universal
|
||||
approximation theorem of operator*.
|
||||
Nat Mach Intell 3, 218-229 (2021).
|
||||
DOI: `10.1038/s42256-021-00302-5
|
||||
<https://doi.org/10.1038/s42256-021-00302-5>`_
|
||||
|
||||
@@ -318,42 +364,44 @@ class DeepONet(MIONet):
|
||||
translation=True,
|
||||
):
|
||||
"""
|
||||
Initialization of the :class:`DeepONet` class.
|
||||
|
||||
:param torch.nn.Module branch_net: The neural network to use as branch
|
||||
model. It has to take as input a
|
||||
:py:obj:`pina.label_tensor.LabelTensor` or :class:`torch.Tensor`.
|
||||
The number of dimensions of the output has to be the same of the
|
||||
``trunk_net``.
|
||||
model. It has to take as input either a
|
||||
:class:`~pina.label_tensor.LabelTensor` or a :class:`torch.Tensor`.
|
||||
The output dimension has to be the same as that of ``trunk_net``.
|
||||
:param torch.nn.Module trunk_net: The neural network to use as trunk
|
||||
model. It has to take as input a
|
||||
:py:obj:`pina.label_tensor.LabelTensor` or :class:`torch.Tensor`.
|
||||
The number of dimensions of the output has to be the same of the
|
||||
``branch_net``.
|
||||
:param list(int) or list(str) input_indeces_branch_net: List of indeces
|
||||
to extract from the input variable in the forward pass for the
|
||||
branch net. If a list of ``int`` is passed, the corresponding
|
||||
columns of the inner most entries are extracted. If a list of
|
||||
``str`` is passed the variables of the corresponding
|
||||
:py:obj:`pina.label_tensor.LabelTensor` are extracted.
|
||||
:param list(int) or list(str) input_indeces_trunk_net: List of indeces
|
||||
to extract from the input variable in the forward pass for the
|
||||
trunk net. If a list of ``int`` is passed, the corresponding columns
|
||||
of the inner most entries are extracted. If a list of ``str`` is
|
||||
passed the variables of the corresponding
|
||||
:py:obj:`pina.label_tensor.LabelTensor` are extracted.
|
||||
:param str or Callable aggregator: Aggregator to be used to aggregate
|
||||
partial results from the modules in `nets`. Partial results are
|
||||
aggregated component-wise. Available aggregators include
|
||||
sum: ``+``, product: ``*``, mean: ``mean``, min: ``min``,
|
||||
max: ``max``.
|
||||
:param str or Callable reduction: Reduction to be used to reduce
|
||||
the aggregated result of the modules in `nets` to the desired output
|
||||
dimension. Available reductions include
|
||||
sum: ``+``, product: ``*``, mean: ``mean``, min: ``min``,
|
||||
max: ``max``.
|
||||
:param bool or Callable scale: Scaling the final output before returning
|
||||
the forward pass, default True.
|
||||
:param bool or Callable translation: Translating the final output before
|
||||
returning the forward pass, default True.
|
||||
model. It has to take as input either a
|
||||
:class:`~pina.label_tensor.LabelTensor` or a :class:`torch.Tensor`.
|
||||
The output dimension has to be the same as that of ``branch_net``.
|
||||
:param input_indeces_branch_net: List of indeces to extract from the
|
||||
input variable of the ``branch_net``.
|
||||
If a list of ``int`` is passed, the corresponding columns of the
|
||||
inner most entries are extracted. If a list of ``str`` is passed the
|
||||
variables of the corresponding
|
||||
:class:`~pina.label_tensor.LabelTensor` are extracted.
|
||||
:type input_indeces_branch_net: list[int] | list[str]
|
||||
:param input_indeces_trunk_net: List of indeces to extract from the
|
||||
input variable of the ``trunk_net``.
|
||||
If a list of ``int`` is passed, the corresponding columns of the
|
||||
inner most entries are extracted. If a list of ``str`` is passed the
|
||||
variables of the corresponding
|
||||
:class:`~pina.label_tensor.LabelTensor` are extracted.
|
||||
:type input_indeces_trunk_net: list[int] | list[str]
|
||||
:param aggregator: The aggregator to be used to aggregate component-wise
|
||||
partial results from the modules in ``networks``. Available
|
||||
aggregators include: sum: ``+``, product: ``*``, mean: ``mean``,
|
||||
min: ``min``, max: ``max``. Default is ``*``.
|
||||
:type aggregator: str or Callable
|
||||
:param reduction: The reduction to be used to reduce the aggregated
|
||||
result of the modules in ``networks`` to the desired output
|
||||
dimension. Available reductions include: sum: ``+``, product: ``*``,
|
||||
mean: ``mean``, min: ``min``, max: ``max``. Default is ``+``.
|
||||
:type reduction: str or Callable
|
||||
:param bool scale: If ``True``, the final output is scaled before being
|
||||
returned in the forward pass. Default is ``True``.
|
||||
:param bool translation: If ``True``, the final output is translated
|
||||
before being returned in the forward pass. Default is ``True``.
|
||||
|
||||
.. warning::
|
||||
In the forward pass we do not check if the input is instance of
|
||||
@@ -364,6 +412,14 @@ class DeepONet(MIONet):
|
||||
Differently, for a :class:`torch.Tensor` only a list of integers can
|
||||
be passed for ``input_indeces_branch_net`` and
|
||||
``input_indeces_trunk_net``.
|
||||
.. warning::
|
||||
No checks are performed in the forward pass to verify if the input
|
||||
is instance of either :class:`~pina.label_tensor.LabelTensor` or
|
||||
:class:`torch.Tensor`. In general, in case of a
|
||||
:class:`~pina.label_tensor.LabelTensor`, both a ``list[int]`` or a
|
||||
``list[str]`` can be passed as ``input_indeces_branch_net`` and
|
||||
``input_indeces_trunk_net``. Differently, in case of a
|
||||
:class:`torch.Tensor`, only a ``list[int]`` can be passed.
|
||||
|
||||
:Example:
|
||||
>>> branch_net = FeedForward(input_dimensons=1,
|
||||
@@ -411,25 +467,31 @@ class DeepONet(MIONet):
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Defines the computation performed at every call.
|
||||
Forward pass for the :class:`DeepONet` model.
|
||||
|
||||
:param LabelTensor or torch.Tensor x: The input tensor for the forward
|
||||
call.
|
||||
:return: The output computed by the DeepONet model.
|
||||
:rtype: LabelTensor or torch.Tensor
|
||||
:param x: The input tensor.
|
||||
:type x: LabelTensor | torch.Tensor
|
||||
:return: The output tensor.
|
||||
:rtype: LabelTensor | torch.Tensor
|
||||
"""
|
||||
return super().forward(x)
|
||||
|
||||
@property
|
||||
def branch_net(self):
|
||||
"""
|
||||
The branch net for DeepONet.
|
||||
The branch net of the DeepONet.
|
||||
|
||||
:return: The branch net.
|
||||
:rtype: torch.nn.Module
|
||||
"""
|
||||
return self.models[0]
|
||||
|
||||
@property
|
||||
def trunk_net(self):
|
||||
"""
|
||||
The trunk net for DeepONet.
|
||||
The trunk net of the DeepONet.
|
||||
|
||||
:return: The trunk net.
|
||||
:rtype: torch.nn.Module
|
||||
"""
|
||||
return self.models[1]
|
||||
|
||||
Reference in New Issue
Block a user