MIONet addition and DeepONet modification.

* Create MIONet base class
* DeepONet inherits from MIONet
This commit is contained in:
Dario Coscia
2023-09-13 09:46:52 +02:00
committed by Nicola Demo
parent 603f56d264
commit ba7371f350
3 changed files with 322 additions and 79 deletions

View File

@@ -2,10 +2,11 @@ __all__ = [
'FeedForward',
'MultiFeedForward',
'DeepONet',
'MIONet',
'FNO',
]
from .feed_forward import FeedForward
from .multi_feed_forward import MultiFeedForward
from .deeponet import DeepONet
from .deeponet import DeepONet, MIONet
from .fno import FNO

View File

@@ -5,59 +5,53 @@ from ..utils import check_consistency, is_function
from functools import partial
class DeepONet(torch.nn.Module):
class MIONet(torch.nn.Module):
"""
The PINA implementation of DeepONet network.
The PINA implementation of MIONet network.
DeepONet is a general architecture for learning Operators. Unlike
traditional machine learning methods DeepONet is designed to map
entire functions to other functions. It can be trained both with
Physics Informed or Supervised learning strategies.
MIONet is a general architecture for learning Operators defined
on the tensor product of Banach spaces. Unlike traditional machine
learning methods MIONet is designed to map entire functions to other functions.
It can be trained both with Physics Informed or Supervised learning strategies.
.. seealso::
**Original reference**: Lu, L., Jin, P., Pang, G. et al. *Learning
nonlinear operators via DeepONet based on the universal approximation
theorem of operators*. Nat Mach Intell 3, 218229 (2021).
DOI: `10.1038/s42256-021-00302-5
<https://doi.org/10.1038/s42256-021-00302-5>`_
**Original reference**: Jin, Pengzhan, Shuai Meng, and Lu Lu.
"MIONet: Learning multiple-input operators via tensor product."
SIAM Journal on Scientific Computing 44.6 (2022): A3490-A351
DOI: `10.1137/22M1477751
<https://doi.org/10.1137/22M1477751>`_
"""
def __init__(self,
branch_net,
trunk_net,
input_indeces_branch_net,
input_indeces_trunk_net,
networks,
aggregator="*",
reduction="+"):
reduction="+",
scale=True,
translation=True):
"""
:param torch.nn.Module branch_net: The neural network to use as branch
model. It has to take as input a :class:`LabelTensor`
or :class:`torch.Tensor`. The number of dimensions of the output has
to be the same of the ``trunk_net``.
:param torch.nn.Module trunk_net: The neural network to use as trunk
model. It has to take as input a :class:`LabelTensor`
or :class:`torch.Tensor`. The number of dimensions of the output
has to be the same of the ``branch_net``.
:param list(int) | list(str) input_indeces_branch_net: List of indeces
to extract from the input variable in the forward pass for the
branch net. If a list of ``int`` is passed, the corresponding columns
of the inner most entries are extracted. If a list of ``str`` is passed
the variables of the corresponding :class:`LabelTensor` are extracted.
:param list(int) | list(str) input_indeces_trunk_net: List of indeces
to extract from the input variable in the forward pass for the
trunk net. If a list of ``int`` is passed, the corresponding columns
of the inner most entries are extracted. If a list of ``str`` is passed
the variables of the corresponding :class:`LabelTensor` are extracted.
:param dict networks: The neural networks to use as
models. The ``dict`` takes as key a neural network, and
as value the list of indeces to extract from the input variable
in the forward pass of the neural network. If a list of ``int`` is passed,
the corresponding columns of the inner most entries are extracted.
If a list of ``str`` is passed the variables of the corresponding :class:`LabelTensor`
are extracted. The ``torch.nn.Module`` model has to take as input a
:class:`LabelTensor` or :class:`torch.Tensor`. Default implementation consist of different
branch nets and one trunk net.
:param str | callable aggregator: Aggregator to be used to aggregate
partial results from the modules in `nets`. Partial results are
aggregated component-wise. See
:func:`pina.model.deeponet.DeepONet._symbol_functions` for the
:func:`pina.model.deeponet.MIONet._symbol_functions` for the
available default aggregators.
:param str | callable reduction: Reduction to be used to reduce
the aggregated result of the modules in `nets` to the desired output
dimension. See :py:obj:`pina.model.deeponet.DeepONet._symbol_functions` for the available default
reductions.
dimension. See :py:obj:`pina.model.deeponet.MIONet._symbol_functions`
for the available default reductions.
:param bool | callable scale: Scaling the final output before returning the
forward pass, default True.
:param bool | callable translation: Translating the final output before
returning the forward pass, default True.
.. warning::
In the forward pass we do not check if the input is instance of
@@ -69,32 +63,44 @@ class DeepONet(torch.nn.Module):
and ``input_indeces_trunk_net``.
:Example:
>>> branch_net = FeedForward(input_dimensons=1, output_dimensions=10)
>>> branch_net1 = FeedForward(input_dimensons=1, output_dimensions=10)
>>> branch_net2 = FeedForward(input_dimensons=2, output_dimensions=10)
>>> trunk_net = FeedForward(input_dimensons=1, output_dimensions=10)
>>> model = DeepONet(branch_net=branch_net,
... trunk_net=trunk_net,
... input_indeces_branch_net=['x'],
... input_indeces_trunk_net=['t'],
... reduction='+',
... aggregator='*')
>>> networks = {branch_net1 : ['x'],
branch_net2 : ['x', 'y'],
... trunk_net : ['z']}
>>> model = MIONet(networks=networks,
... reduction='+',
... aggregator='*')
>>> model
DeepONet(
(trunk_net): FeedForward(
MIONet(
(models): ModuleList(
(0): FeedForward(
(model): Sequential(
(0): Linear(in_features=1, out_features=20, bias=True)
(1): Tanh()
(2): Linear(in_features=20, out_features=20, bias=True)
(3): Tanh()
(4): Linear(in_features=20, out_features=10, bias=True)
(0): Linear(in_features=1, out_features=20, bias=True)
(1): Tanh()
(2): Linear(in_features=20, out_features=20, bias=True)
(3): Tanh()
(4): Linear(in_features=20, out_features=10, bias=True)
)
)
(branch_net): FeedForward(
)
(1): FeedForward(
(model): Sequential(
(0): Linear(in_features=1, out_features=20, bias=True)
(1): Tanh()
(2): Linear(in_features=20, out_features=20, bias=True)
(3): Tanh()
(4): Linear(in_features=20, out_features=10, bias=True)
(0): Linear(in_features=2, out_features=20, bias=True)
(1): Tanh()
(2): Linear(in_features=20, out_features=20, bias=True)
(3): Tanh()
(4): Linear(in_features=20, out_features=10, bias=True)
)
)
(2): FeedForward(
(model): Sequential(
(0): Linear(in_features=1, out_features=20, bias=True)
(1): Tanh()
(2): Linear(in_features=20, out_features=20, bias=True)
(3): Tanh()
(4): Linear(in_features=20, out_features=10, bias=True)
)
)
)
)
@@ -102,33 +108,33 @@ class DeepONet(torch.nn.Module):
super().__init__()
# check type consistency
check_consistency(input_indeces_branch_net, (str, int))
check_consistency(input_indeces_trunk_net, (str, int))
check_consistency(branch_net, torch.nn.Module)
check_consistency(trunk_net, torch.nn.Module)
check_consistency(networks, dict)
check_consistency(scale, bool)
check_consistency(translation, bool)
# check trunk branch nets consistency
input_trunk = torch.rand(10, len(input_indeces_trunk_net))
input_branch = torch.rand(10, len(input_indeces_branch_net))
trunk_out_dim = trunk_net(input_trunk).shape[-1]
branch_out_dim = branch_net(input_branch).shape[-1]
if trunk_out_dim != branch_out_dim:
raise ValueError('Branch and trunk networks have not the same '
shapes = []
for key, value in networks.items():
check_consistency(value, (str, int))
check_consistency(key, torch.nn.Module)
input_ = torch.rand(10, len(value))
shapes.append(key(input_).shape[-1])
if not all(map(lambda x: x == shapes[0], shapes)):
raise ValueError('The passed networks have not the same '
'output dimension.')
# assign trunk and branch net with their input indeces
self.trunk_net = trunk_net
self._trunk_indeces = input_indeces_trunk_net
self.branch_net = branch_net
self._branch_indeces = input_indeces_branch_net
self.models = torch.nn.ModuleList(networks.keys())
self._indeces = networks.values()
# initializie aggregation
self._init_aggregator(aggregator=aggregator)
self._init_reduction(reduction=reduction)
# scale and translation
self._scale = torch.nn.Parameter(torch.tensor([1.0]))
self._trasl = torch.nn.Parameter(torch.tensor([0.0]))
self._scale = torch.nn.Parameter(torch.tensor([1.0])) if scale else torch.tensor([1.0])
self._trasl = torch.nn.Parameter(torch.tensor([1.0])) if translation else torch.tensor([1.0])
@staticmethod
def _symbol_functions(**kwargs):
@@ -192,11 +198,10 @@ class DeepONet(torch.nn.Module):
"""
# forward pass
branch_output = self.branch_net(self._get_vars(x, self._branch_indeces))
trunk_output = self.trunk_net(self._get_vars(x, self._trunk_indeces))
output_ = [model(self._get_vars(x, indeces)) for model, indeces in zip(self.models, self._indeces)]
# aggregation
aggregated = self._aggregator(torch.dstack((branch_output, trunk_output)))
aggregated = self._aggregator(torch.dstack(output_))
# reduce
output_ = self._reduction(aggregated).reshape(-1, 1)
@@ -205,4 +210,169 @@ class DeepONet(torch.nn.Module):
output_ *= self._scale
output_ += self._trasl
return output_
return output_
@property
def aggregator(self):
"""
The aggregator function.
"""
return self._aggregator
@property
def reduction(self):
"""
The translation factor.
"""
return self._reduction
@property
def scale(self):
"""
The scale factor.
"""
return self._scale
@property
def translation(self):
"""
The translation factor for MIONet.
"""
return self._trasl
@property
def indeces_variables_extracted(self):
"""
The input indeces for each model in form of list.
"""
return self._indeces
@property
def model(self):
"""
The models in form of list.
"""
return self._indeces
class DeepONet(MIONet):
"""
The PINA implementation of DeepONet network.
DeepONet is a general architecture for learning Operators. Unlike
traditional machine learning methods DeepONet is designed to map
entire functions to other functions. It can be trained both with
Physics Informed or Supervised learning strategies.
.. seealso::
**Original reference**: Lu, L., Jin, P., Pang, G. et al. *Learning
nonlinear operators via DeepONet based on the universal approximation
theorem of operators*. Nat Mach Intell 3, 218229 (2021).
DOI: `10.1038/s42256-021-00302-5
<https://doi.org/10.1038/s42256-021-00302-5>`_
"""
def __init__(self,
branch_net,
trunk_net,
input_indeces_branch_net,
input_indeces_trunk_net,
aggregator="*",
reduction="+",
scale=True,
translation=True):
"""
:param torch.nn.Module branch_net: The neural network to use as branch
model. It has to take as input a :class:`LabelTensor`
or :class:`torch.Tensor`. The number of dimensions of the output has
to be the same of the ``trunk_net``.
:param torch.nn.Module trunk_net: The neural network to use as trunk
model. It has to take as input a :class:`LabelTensor`
or :class:`torch.Tensor`. The number of dimensions of the output
has to be the same of the ``branch_net``.
:param list(int) | list(str) input_indeces_branch_net: List of indeces
to extract from the input variable in the forward pass for the
branch net. If a list of ``int`` is passed, the corresponding columns
of the inner most entries are extracted. If a list of ``str`` is passed
the variables of the corresponding :class:`LabelTensor` are extracted.
:param list(int) | list(str) input_indeces_trunk_net: List of indeces
to extract from the input variable in the forward pass for the
trunk net. If a list of ``int`` is passed, the corresponding columns
of the inner most entries are extracted. If a list of ``str`` is passed
the variables of the corresponding :class:`LabelTensor` are extracted.
:param str | callable aggregator: Aggregator to be used to aggregate
partial results from the modules in `nets`. Partial results are
aggregated component-wise. See
:func:`pina.model.deeponet.MIONet._symbol_functions` for the
available default aggregators.
:param str | callable reduction: Reduction to be used to reduce
the aggregated result of the modules in `nets` to the desired output
dimension. See :py:obj:`pina.model.deeponet.MIONet._symbol_functions` for the available default
reductions.
:param bool | callable scale: Scaling the final output before returning the
forward pass, default True.
:param bool | callable translation: Translating the final output before
returning the forward pass, default True.
.. warning::
In the forward pass we do not check if the input is instance of
:class:`LabelTensor` or :class:`torch.Tensor`. A general rule is
that for a :class:`LabelTensor` input both list of integers and
list of strings can be passed for ``input_indeces_branch_net``
and ``input_indeces_trunk_net``. Differently, for a :class:`torch.Tensor`
only a list of integers can be passed for ``input_indeces_branch_net``
and ``input_indeces_trunk_net``.
:Example:
>>> branch_net = FeedForward(input_dimensons=1, output_dimensions=10)
>>> trunk_net = FeedForward(input_dimensons=1, output_dimensions=10)
>>> model = DeepONet(branch_net=branch_net,
... trunk_net=trunk_net,
... input_indeces_branch_net=['x'],
... input_indeces_trunk_net=['t'],
... reduction='+',
... aggregator='*')
>>> model
DeepONet(
(trunk_net): FeedForward(
(model): Sequential(
(0): Linear(in_features=1, out_features=20, bias=True)
(1): Tanh()
(2): Linear(in_features=20, out_features=20, bias=True)
(3): Tanh()
(4): Linear(in_features=20, out_features=10, bias=True)
)
)
(branch_net): FeedForward(
(model): Sequential(
(0): Linear(in_features=1, out_features=20, bias=True)
(1): Tanh()
(2): Linear(in_features=20, out_features=20, bias=True)
(3): Tanh()
(4): Linear(in_features=20, out_features=10, bias=True)
)
)
)
"""
networks = {branch_net : input_indeces_branch_net,
trunk_net : input_indeces_trunk_net}
super().__init__(networks=networks,
aggregator=aggregator,
reduction=reduction,
scale=scale,
translation=translation)
@property
def branch_net(self):
"""
The branch net for DeepONet.
"""
return self.models[0]
@property
def trunk_net(self):
"""
The trunk net for DeepONet.
"""
return self.models[1]