Documentation for v0.1 version (#199)
* Adding Equations, solving typos * improve _code.rst * the team rst and restuctore index.rst * fixing errors --------- Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
This commit is contained in:
committed by
Nicola Demo
parent
3f9305d475
commit
8b7b61b3bd
@@ -17,15 +17,16 @@ class MIONet(torch.nn.Module):
|
||||
.. seealso::
|
||||
|
||||
**Original reference**: Jin, Pengzhan, Shuai Meng, and Lu Lu.
|
||||
"MIONet: Learning multiple-input operators via tensor product."
|
||||
*MIONet: Learning multiple-input operators via tensor product.*
|
||||
SIAM Journal on Scientific Computing 44.6 (2022): A3490-A351
|
||||
DOI: `10.1137/22M1477751
|
||||
<https://doi.org/10.1137/22M1477751>`_
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
networks,
|
||||
aggregator="*",
|
||||
aggregator="*",
|
||||
reduction="+",
|
||||
scale=True,
|
||||
translation=True):
|
||||
@@ -35,28 +36,27 @@ class MIONet(torch.nn.Module):
|
||||
as value the list of indeces to extract from the input variable
|
||||
in the forward pass of the neural network. If a list of ``int`` is passed,
|
||||
the corresponding columns of the inner most entries are extracted.
|
||||
If a list of ``str`` is passed the variables of the corresponding :class:`LabelTensor`
|
||||
If a list of ``str`` is passed the variables of the corresponding :py:obj:`pina.label_tensor.LabelTensor`
|
||||
are extracted. The ``torch.nn.Module`` model has to take as input a
|
||||
:class:`LabelTensor` or :class:`torch.Tensor`. Default implementation consist of different
|
||||
branch nets and one trunk net.
|
||||
:param str | callable aggregator: Aggregator to be used to aggregate
|
||||
:py:obj:`pina.label_tensor.LabelTensor` or :class:`torch.Tensor`.
|
||||
Default implementation consist of different branch nets and one trunk nets.
|
||||
:param str or Callable aggregator: Aggregator to be used to aggregate
|
||||
partial results from the modules in `nets`. Partial results are
|
||||
aggregated component-wise. See
|
||||
:func:`pina.model.deeponet.MIONet._symbol_functions` for the
|
||||
available default aggregators.
|
||||
:param str | callable reduction: Reduction to be used to reduce
|
||||
aggregated component-wise. Available aggregators include
|
||||
sum: ``+``, product: ``*``, mean: ``mean``, min: ``min``, max: ``max``.
|
||||
:param str or Callable reduction: Reduction to be used to reduce
|
||||
the aggregated result of the modules in `nets` to the desired output
|
||||
dimension. See :py:obj:`pina.model.deeponet.MIONet._symbol_functions`
|
||||
for the available default reductions.
|
||||
:param bool | callable scale: Scaling the final output before returning the
|
||||
forward pass, default True.
|
||||
:param bool | callable translation: Translating the final output before
|
||||
returning the forward pass, default True.
|
||||
dimension. Available reductions include
|
||||
sum: ``+``, product: ``*``, mean: ``mean``, min: ``min``, max: ``max``.
|
||||
:param bool or Callable scale: Scaling the final output before returning the
|
||||
forward pass, default ``True``.
|
||||
:param bool or Callable translation: Translating the final output before
|
||||
returning the forward pass, default ``True``.
|
||||
|
||||
.. warning::
|
||||
In the forward pass we do not check if the input is instance of
|
||||
:class:`LabelTensor` or :class:`torch.Tensor`. A general rule is
|
||||
that for a :class:`LabelTensor` input both list of integers and
|
||||
:py:obj:`pina.label_tensor.LabelTensor` or :class:`torch.Tensor`. A general rule is
|
||||
that for a :py:obj:`pina.label_tensor.LabelTensor` input both list of integers and
|
||||
list of strings can be passed for ``input_indeces_branch_net``
|
||||
and ``input_indeces_trunk_net``. Differently, for a :class:`torch.Tensor`
|
||||
only a list of integers can be passed for ``input_indeces_branch_net``
|
||||
@@ -133,8 +133,10 @@ class MIONet(torch.nn.Module):
|
||||
self._init_reduction(reduction=reduction)
|
||||
|
||||
# scale and translation
|
||||
self._scale = torch.nn.Parameter(torch.tensor([1.0])) if scale else torch.tensor([1.0])
|
||||
self._trasl = torch.nn.Parameter(torch.tensor([1.0])) if translation else torch.tensor([1.0])
|
||||
self._scale = torch.nn.Parameter(torch.tensor(
|
||||
[1.0])) if scale else torch.tensor([1.0])
|
||||
self._trasl = torch.nn.Parameter(torch.tensor(
|
||||
[1.0])) if translation else torch.tensor([1.0])
|
||||
|
||||
@staticmethod
|
||||
def _symbol_functions(**kwargs):
|
||||
@@ -149,7 +151,7 @@ class MIONet(torch.nn.Module):
|
||||
"min": lambda x: torch.min(x, **kwargs).values,
|
||||
"max": lambda x: torch.max(x, **kwargs).values,
|
||||
}
|
||||
|
||||
|
||||
def _init_aggregator(self, aggregator):
|
||||
aggregator_funcs = DeepONet._symbol_functions(dim=2)
|
||||
if aggregator in aggregator_funcs:
|
||||
@@ -161,7 +163,6 @@ class MIONet(torch.nn.Module):
|
||||
|
||||
self._aggregator = aggregator_func
|
||||
|
||||
|
||||
def _init_reduction(self, reduction):
|
||||
reduction_funcs = DeepONet._symbol_functions(dim=-1)
|
||||
if reduction in reduction_funcs:
|
||||
@@ -178,27 +179,32 @@ class MIONet(torch.nn.Module):
|
||||
try:
|
||||
return x.extract(indeces)
|
||||
except AttributeError:
|
||||
raise RuntimeError('Not possible to extract input variables from tensor.'
|
||||
' Ensure that the passed tensor is a LabelTensor or'
|
||||
' pass list of integers to extract variables. For'
|
||||
' more information refer to warning in the documentation.')
|
||||
raise RuntimeError(
|
||||
'Not possible to extract input variables from tensor.'
|
||||
' Ensure that the passed tensor is a LabelTensor or'
|
||||
' pass list of integers to extract variables. For'
|
||||
' more information refer to warning in the documentation.')
|
||||
elif isinstance(indeces[0], int):
|
||||
return x[..., indeces]
|
||||
else:
|
||||
raise RuntimeError('Not able to extract right indeces for tensor.'
|
||||
' For more information refer to warning in the documentation.')
|
||||
|
||||
raise RuntimeError(
|
||||
'Not able to extract right indeces for tensor.'
|
||||
' For more information refer to warning in the documentation.')
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Defines the computation performed at every call.
|
||||
|
||||
:param LabelTensor | torch.Tensor x: The input tensor for the forward call.
|
||||
:param LabelTensor or torch.Tensor x: The input tensor for the forward call.
|
||||
:return: The output computed by the DeepONet model.
|
||||
:rtype: LabelTensor | torch.Tensor
|
||||
:rtype: LabelTensor or torch.Tensor
|
||||
"""
|
||||
|
||||
# forward pass
|
||||
output_ = [model(self._get_vars(x, indeces)) for model, indeces in zip(self.models, self._indeces)]
|
||||
output_ = [
|
||||
model(self._get_vars(x, indeces))
|
||||
for model, indeces in zip(self.models, self._indeces)
|
||||
]
|
||||
|
||||
# aggregation
|
||||
aggregated = self._aggregator(torch.dstack(output_))
|
||||
@@ -206,7 +212,7 @@ class MIONet(torch.nn.Module):
|
||||
# reduce
|
||||
output_ = self._reduction(aggregated).reshape(-1, 1)
|
||||
|
||||
# scale and translate
|
||||
# scale and translate
|
||||
output_ *= self._scale
|
||||
output_ += self._trasl
|
||||
|
||||
@@ -218,7 +224,7 @@ class MIONet(torch.nn.Module):
|
||||
The aggregator function.
|
||||
"""
|
||||
return self._aggregator
|
||||
|
||||
|
||||
@property
|
||||
def reduction(self):
|
||||
"""
|
||||
@@ -232,28 +238,28 @@ class MIONet(torch.nn.Module):
|
||||
The scale factor.
|
||||
"""
|
||||
return self._scale
|
||||
|
||||
|
||||
@property
|
||||
def translation(self):
|
||||
"""
|
||||
The translation factor for MIONet.
|
||||
"""
|
||||
return self._trasl
|
||||
|
||||
|
||||
@property
|
||||
def indeces_variables_extracted(self):
|
||||
"""
|
||||
The input indeces for each model in form of list.
|
||||
"""
|
||||
return self._indeces
|
||||
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
"""
|
||||
The models in form of list.
|
||||
"""
|
||||
return self._indeces
|
||||
|
||||
|
||||
|
||||
class DeepONet(MIONet):
|
||||
"""
|
||||
@@ -273,52 +279,52 @@ class DeepONet(MIONet):
|
||||
<https://doi.org/10.1038/s42256-021-00302-5>`_
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
branch_net,
|
||||
trunk_net,
|
||||
input_indeces_branch_net,
|
||||
input_indeces_trunk_net,
|
||||
aggregator="*",
|
||||
aggregator="*",
|
||||
reduction="+",
|
||||
scale=True,
|
||||
translation=True):
|
||||
"""
|
||||
:param torch.nn.Module branch_net: The neural network to use as branch
|
||||
model. It has to take as input a :class:`LabelTensor`
|
||||
model. It has to take as input a :py:obj:`pina.label_tensor.LabelTensor`
|
||||
or :class:`torch.Tensor`. The number of dimensions of the output has
|
||||
to be the same of the ``trunk_net``.
|
||||
:param torch.nn.Module trunk_net: The neural network to use as trunk
|
||||
model. It has to take as input a :class:`LabelTensor`
|
||||
model. It has to take as input a :py:obj:`pina.label_tensor.LabelTensor`
|
||||
or :class:`torch.Tensor`. The number of dimensions of the output
|
||||
has to be the same of the ``branch_net``.
|
||||
:param list(int) | list(str) input_indeces_branch_net: List of indeces
|
||||
:param list(int) or list(str) input_indeces_branch_net: List of indeces
|
||||
to extract from the input variable in the forward pass for the
|
||||
branch net. If a list of ``int`` is passed, the corresponding columns
|
||||
of the inner most entries are extracted. If a list of ``str`` is passed
|
||||
the variables of the corresponding :class:`LabelTensor` are extracted.
|
||||
:param list(int) | list(str) input_indeces_trunk_net: List of indeces
|
||||
the variables of the corresponding :py:obj:`pina.label_tensor.LabelTensor` are extracted.
|
||||
:param list(int) or list(str) input_indeces_trunk_net: List of indeces
|
||||
to extract from the input variable in the forward pass for the
|
||||
trunk net. If a list of ``int`` is passed, the corresponding columns
|
||||
of the inner most entries are extracted. If a list of ``str`` is passed
|
||||
the variables of the corresponding :class:`LabelTensor` are extracted.
|
||||
:param str | callable aggregator: Aggregator to be used to aggregate
|
||||
the variables of the corresponding :py:obj:`pina.label_tensor.LabelTensor` are extracted.
|
||||
:param str or Callable aggregator: Aggregator to be used to aggregate
|
||||
partial results from the modules in `nets`. Partial results are
|
||||
aggregated component-wise. See
|
||||
:func:`pina.model.deeponet.MIONet._symbol_functions` for the
|
||||
available default aggregators.
|
||||
:param str | callable reduction: Reduction to be used to reduce
|
||||
aggregated component-wise. Available aggregators include
|
||||
sum: ``+``, product: ``*``, mean: ``mean``, min: ``min``, max: ``max``.
|
||||
:param str or Callable reduction: Reduction to be used to reduce
|
||||
the aggregated result of the modules in `nets` to the desired output
|
||||
dimension. See :py:obj:`pina.model.deeponet.MIONet._symbol_functions` for the available default
|
||||
reductions.
|
||||
:param bool | callable scale: Scaling the final output before returning the
|
||||
dimension. Available reductions include
|
||||
sum: ``+``, product: ``*``, mean: ``mean``, min: ``min``, max: ``max``.
|
||||
:param bool or Callable scale: Scaling the final output before returning the
|
||||
forward pass, default True.
|
||||
:param bool | callable translation: Translating the final output before
|
||||
:param bool or Callable translation: Translating the final output before
|
||||
returning the forward pass, default True.
|
||||
|
||||
.. warning::
|
||||
In the forward pass we do not check if the input is instance of
|
||||
:class:`LabelTensor` or :class:`torch.Tensor`. A general rule is
|
||||
that for a :class:`LabelTensor` input both list of integers and
|
||||
:py:obj:`pina.label_tensor.LabelTensor` or :class:`torch.Tensor`. A general rule is
|
||||
that for a :py:obj:`pina.label_tensor.LabelTensor` input both list of integers and
|
||||
list of strings can be passed for ``input_indeces_branch_net``
|
||||
and ``input_indeces_trunk_net``. Differently, for a :class:`torch.Tensor`
|
||||
only a list of integers can be passed for ``input_indeces_branch_net``
|
||||
@@ -355,24 +361,38 @@ class DeepONet(MIONet):
|
||||
)
|
||||
)
|
||||
"""
|
||||
networks = {branch_net : input_indeces_branch_net,
|
||||
trunk_net : input_indeces_trunk_net}
|
||||
networks = {
|
||||
branch_net: input_indeces_branch_net,
|
||||
trunk_net: input_indeces_trunk_net
|
||||
}
|
||||
super().__init__(networks=networks,
|
||||
aggregator=aggregator,
|
||||
reduction=reduction,
|
||||
scale=scale,
|
||||
translation=translation)
|
||||
|
||||
@property
|
||||
def branch_net(self):
|
||||
"""
|
||||
The branch net for DeepONet.
|
||||
"""
|
||||
return self.models[0]
|
||||
|
||||
@property
|
||||
def trunk_net(self):
|
||||
"""
|
||||
The trunk net for DeepONet.
|
||||
"""
|
||||
return self.models[1]
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Defines the computation performed at every call.
|
||||
|
||||
:param LabelTensor or torch.Tensor x: The input tensor for the forward call.
|
||||
:return: The output computed by the DeepONet model.
|
||||
:rtype: LabelTensor or torch.Tensor
|
||||
"""
|
||||
return super().forward(x)
|
||||
|
||||
|
||||
@property
|
||||
def branch_net(self):
|
||||
"""
|
||||
The branch net for DeepONet.
|
||||
"""
|
||||
return self.models[0]
|
||||
|
||||
@property
|
||||
def trunk_net(self):
|
||||
"""
|
||||
The trunk net for DeepONet.
|
||||
"""
|
||||
return self.models[1]
|
||||
|
||||
Reference in New Issue
Block a user