MIONet addition and DeepONet modification.
* Create MIONet base class * DeepONet inherits from MIONet
This commit is contained in:
committed by
Nicola Demo
parent
603f56d264
commit
ba7371f350
@@ -2,10 +2,11 @@ __all__ = [
|
||||
'FeedForward',
|
||||
'MultiFeedForward',
|
||||
'DeepONet',
|
||||
'MIONet',
|
||||
'FNO',
|
||||
]
|
||||
|
||||
from .feed_forward import FeedForward
|
||||
from .multi_feed_forward import MultiFeedForward
|
||||
from .deeponet import DeepONet
|
||||
from .deeponet import DeepONet, MIONet
|
||||
from .fno import FNO
|
||||
|
||||
@@ -5,59 +5,53 @@ from ..utils import check_consistency, is_function
|
||||
from functools import partial
|
||||
|
||||
|
||||
class DeepONet(torch.nn.Module):
|
||||
class MIONet(torch.nn.Module):
|
||||
"""
|
||||
The PINA implementation of DeepONet network.
|
||||
The PINA implementation of MIONet network.
|
||||
|
||||
DeepONet is a general architecture for learning Operators. Unlike
|
||||
traditional machine learning methods DeepONet is designed to map
|
||||
entire functions to other functions. It can be trained both with
|
||||
Physics Informed or Supervised learning strategies.
|
||||
MIONet is a general architecture for learning Operators defined
|
||||
on the tensor product of Banach spaces. Unlike traditional machine
|
||||
learning methods MIONet is designed to map entire functions to other functions.
|
||||
It can be trained both with Physics Informed or Supervised learning strategies.
|
||||
|
||||
.. seealso::
|
||||
|
||||
**Original reference**: Lu, L., Jin, P., Pang, G. et al. *Learning
|
||||
nonlinear operators via DeepONet based on the universal approximation
|
||||
theorem of operators*. Nat Mach Intell 3, 218–229 (2021).
|
||||
DOI: `10.1038/s42256-021-00302-5
|
||||
<https://doi.org/10.1038/s42256-021-00302-5>`_
|
||||
**Original reference**: Jin, Pengzhan, Shuai Meng, and Lu Lu.
|
||||
"MIONet: Learning multiple-input operators via tensor product."
|
||||
SIAM Journal on Scientific Computing 44.6 (2022): A3490-A3514.
|
||||
DOI: `10.1137/22M1477751
|
||||
<https://doi.org/10.1137/22M1477751>`_
|
||||
|
||||
"""
|
||||
def __init__(self,
|
||||
branch_net,
|
||||
trunk_net,
|
||||
input_indeces_branch_net,
|
||||
input_indeces_trunk_net,
|
||||
networks,
|
||||
aggregator="*",
|
||||
reduction="+"):
|
||||
reduction="+",
|
||||
scale=True,
|
||||
translation=True):
|
||||
"""
|
||||
:param torch.nn.Module branch_net: The neural network to use as branch
|
||||
model. It has to take as input a :class:`LabelTensor`
|
||||
or :class:`torch.Tensor`. The number of dimensions of the output has
|
||||
to be the same of the ``trunk_net``.
|
||||
:param torch.nn.Module trunk_net: The neural network to use as trunk
|
||||
model. It has to take as input a :class:`LabelTensor`
|
||||
or :class:`torch.Tensor`. The number of dimensions of the output
|
||||
has to be the same of the ``branch_net``.
|
||||
:param list(int) | list(str) input_indeces_branch_net: List of indeces
|
||||
to extract from the input variable in the forward pass for the
|
||||
branch net. If a list of ``int`` is passed, the corresponding columns
|
||||
of the inner most entries are extracted. If a list of ``str`` is passed
|
||||
the variables of the corresponding :class:`LabelTensor` are extracted.
|
||||
:param list(int) | list(str) input_indeces_trunk_net: List of indeces
|
||||
to extract from the input variable in the forward pass for the
|
||||
trunk net. If a list of ``int`` is passed, the corresponding columns
|
||||
of the inner most entries are extracted. If a list of ``str`` is passed
|
||||
the variables of the corresponding :class:`LabelTensor` are extracted.
|
||||
:param dict networks: The neural networks to use as
|
||||
models. The ``dict`` takes as key a neural network, and
|
||||
as value the list of indeces to extract from the input variable
|
||||
in the forward pass of the neural network. If a list of ``int`` is passed,
|
||||
the corresponding columns of the inner most entries are extracted.
|
||||
If a list of ``str`` is passed the variables of the corresponding :class:`LabelTensor`
|
||||
are extracted. The ``torch.nn.Module`` model has to take as input a
|
||||
:class:`LabelTensor` or :class:`torch.Tensor`. Default implementation consists of different
|
||||
branch nets and one trunk net.
|
||||
:param str | callable aggregator: Aggregator to be used to aggregate
|
||||
partial results from the modules in `nets`. Partial results are
|
||||
aggregated component-wise. See
|
||||
:func:`pina.model.deeponet.DeepONet._symbol_functions` for the
|
||||
:func:`pina.model.deeponet.MIONet._symbol_functions` for the
|
||||
available default aggregators.
|
||||
:param str | callable reduction: Reduction to be used to reduce
|
||||
the aggregated result of the modules in `nets` to the desired output
|
||||
dimension. See :py:obj:`pina.model.deeponet.DeepONet._symbol_functions` for the available default
|
||||
reductions.
|
||||
dimension. See :py:obj:`pina.model.deeponet.MIONet._symbol_functions`
|
||||
for the available default reductions.
|
||||
:param bool | callable scale: Scaling the final output before returning the
|
||||
forward pass, default True.
|
||||
:param bool | callable translation: Translating the final output before
|
||||
returning the forward pass, default True.
|
||||
|
||||
.. warning::
|
||||
In the forward pass we do not check if the input is instance of
|
||||
@@ -69,32 +63,44 @@ class DeepONet(torch.nn.Module):
|
||||
and ``input_indeces_trunk_net``.
|
||||
|
||||
:Example:
|
||||
>>> branch_net = FeedForward(input_dimensons=1, output_dimensions=10)
|
||||
>>> branch_net1 = FeedForward(input_dimensions=1, output_dimensions=10)
|
||||
>>> branch_net2 = FeedForward(input_dimensions=2, output_dimensions=10)
|
||||
>>> trunk_net = FeedForward(input_dimensions=1, output_dimensions=10)
|
||||
>>> model = DeepONet(branch_net=branch_net,
|
||||
... trunk_net=trunk_net,
|
||||
... input_indeces_branch_net=['x'],
|
||||
... input_indeces_trunk_net=['t'],
|
||||
... reduction='+',
|
||||
... aggregator='*')
|
||||
>>> networks = {branch_net1 : ['x'],
|
||||
... branch_net2 : ['x', 'y'],
|
||||
... trunk_net : ['z']}
|
||||
>>> model = MIONet(networks=networks,
|
||||
... reduction='+',
|
||||
... aggregator='*')
|
||||
>>> model
|
||||
DeepONet(
|
||||
(trunk_net): FeedForward(
|
||||
MIONet(
|
||||
(models): ModuleList(
|
||||
(0): FeedForward(
|
||||
(model): Sequential(
|
||||
(0): Linear(in_features=1, out_features=20, bias=True)
|
||||
(1): Tanh()
|
||||
(2): Linear(in_features=20, out_features=20, bias=True)
|
||||
(3): Tanh()
|
||||
(4): Linear(in_features=20, out_features=10, bias=True)
|
||||
(0): Linear(in_features=1, out_features=20, bias=True)
|
||||
(1): Tanh()
|
||||
(2): Linear(in_features=20, out_features=20, bias=True)
|
||||
(3): Tanh()
|
||||
(4): Linear(in_features=20, out_features=10, bias=True)
|
||||
)
|
||||
)
|
||||
(branch_net): FeedForward(
|
||||
)
|
||||
(1): FeedForward(
|
||||
(model): Sequential(
|
||||
(0): Linear(in_features=1, out_features=20, bias=True)
|
||||
(1): Tanh()
|
||||
(2): Linear(in_features=20, out_features=20, bias=True)
|
||||
(3): Tanh()
|
||||
(4): Linear(in_features=20, out_features=10, bias=True)
|
||||
(0): Linear(in_features=2, out_features=20, bias=True)
|
||||
(1): Tanh()
|
||||
(2): Linear(in_features=20, out_features=20, bias=True)
|
||||
(3): Tanh()
|
||||
(4): Linear(in_features=20, out_features=10, bias=True)
|
||||
)
|
||||
)
|
||||
(2): FeedForward(
|
||||
(model): Sequential(
|
||||
(0): Linear(in_features=1, out_features=20, bias=True)
|
||||
(1): Tanh()
|
||||
(2): Linear(in_features=20, out_features=20, bias=True)
|
||||
(3): Tanh()
|
||||
(4): Linear(in_features=20, out_features=10, bias=True)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -102,33 +108,33 @@ class DeepONet(torch.nn.Module):
|
||||
super().__init__()
|
||||
|
||||
# check type consistency
|
||||
check_consistency(input_indeces_branch_net, (str, int))
|
||||
check_consistency(input_indeces_trunk_net, (str, int))
|
||||
check_consistency(branch_net, torch.nn.Module)
|
||||
check_consistency(trunk_net, torch.nn.Module)
|
||||
check_consistency(networks, dict)
|
||||
check_consistency(scale, bool)
|
||||
check_consistency(translation, bool)
|
||||
|
||||
# check trunk branch nets consistency
|
||||
input_trunk = torch.rand(10, len(input_indeces_trunk_net))
|
||||
input_branch = torch.rand(10, len(input_indeces_branch_net))
|
||||
trunk_out_dim = trunk_net(input_trunk).shape[-1]
|
||||
branch_out_dim = branch_net(input_branch).shape[-1]
|
||||
if trunk_out_dim != branch_out_dim:
|
||||
raise ValueError('Branch and trunk networks have not the same '
|
||||
shapes = []
|
||||
for key, value in networks.items():
|
||||
check_consistency(value, (str, int))
|
||||
check_consistency(key, torch.nn.Module)
|
||||
input_ = torch.rand(10, len(value))
|
||||
shapes.append(key(input_).shape[-1])
|
||||
|
||||
if not all(map(lambda x: x == shapes[0], shapes)):
|
||||
raise ValueError('The passed networks have not the same '
|
||||
'output dimension.')
|
||||
|
||||
# assign trunk and branch net with their input indeces
|
||||
self.trunk_net = trunk_net
|
||||
self._trunk_indeces = input_indeces_trunk_net
|
||||
self.branch_net = branch_net
|
||||
self._branch_indeces = input_indeces_branch_net
|
||||
self.models = torch.nn.ModuleList(networks.keys())
|
||||
self._indeces = networks.values()
|
||||
|
||||
# initialize aggregation
|
||||
self._init_aggregator(aggregator=aggregator)
|
||||
self._init_reduction(reduction=reduction)
|
||||
|
||||
# scale and translation
|
||||
self._scale = torch.nn.Parameter(torch.tensor([1.0]))
|
||||
self._trasl = torch.nn.Parameter(torch.tensor([0.0]))
|
||||
self._scale = torch.nn.Parameter(torch.tensor([1.0])) if scale else torch.tensor([1.0])
|
||||
self._trasl = torch.nn.Parameter(torch.tensor([1.0])) if translation else torch.tensor([1.0])
|
||||
|
||||
@staticmethod
|
||||
def _symbol_functions(**kwargs):
|
||||
@@ -192,11 +198,10 @@ class DeepONet(torch.nn.Module):
|
||||
"""
|
||||
|
||||
# forward pass
|
||||
branch_output = self.branch_net(self._get_vars(x, self._branch_indeces))
|
||||
trunk_output = self.trunk_net(self._get_vars(x, self._trunk_indeces))
|
||||
output_ = [model(self._get_vars(x, indeces)) for model, indeces in zip(self.models, self._indeces)]
|
||||
|
||||
# aggregation
|
||||
aggregated = self._aggregator(torch.dstack((branch_output, trunk_output)))
|
||||
aggregated = self._aggregator(torch.dstack(output_))
|
||||
|
||||
# reduce
|
||||
output_ = self._reduction(aggregated).reshape(-1, 1)
|
||||
@@ -205,4 +210,169 @@ class DeepONet(torch.nn.Module):
|
||||
output_ *= self._scale
|
||||
output_ += self._trasl
|
||||
|
||||
return output_
|
||||
return output_
|
||||
|
||||
@property
def aggregator(self):
    """
    Callable applied in the forward pass to the stacked outputs of the
    single networks.

    :return: The object stored in ``_aggregator``.
    """
    aggregator_fn = self._aggregator
    return aggregator_fn
|
||||
|
||||
@property
def reduction(self):
    """
    The reduction function.

    It is the callable applied in the forward pass to the aggregated
    result, before the output is reshaped.

    :return: The object stored in ``_reduction``.
    """
    return self._reduction
|
||||
|
||||
@property
def scale(self):
    """
    Multiplicative factor applied to the output at the end of the
    forward pass.

    :return: The object stored in ``_scale``.
    """
    scale_factor = self._scale
    return scale_factor
|
||||
|
||||
@property
def translation(self):
    """
    Additive term applied to the MIONet output at the end of the
    forward pass.

    :return: The object stored in ``_trasl``.
    """
    shift = self._trasl
    return shift
|
||||
|
||||
@property
def indeces_variables_extracted(self):
    """
    The input indeces for each model in form of list.

    The i-th entry holds the indeces extracted from the input for the
    i-th model. A new list is built on every access: ``_indeces`` may be
    a ``dict.values()`` view, which cannot be indexed, so it is
    materialised here to honour the documented return type.

    :return: The indeces, one entry per model.
    :rtype: list
    """
    return list(self._indeces)
|
||||
|
||||
@property
def model(self):
    """
    The models in form of list.

    :return: The container stored in ``models``, i.e. the networks
        evaluated in the forward pass (not their input indeces).
    """
    return self.models
|
||||
|
||||
|
||||
class DeepONet(MIONet):
    """
    The PINA implementation of DeepONet network.

    DeepONet is a general architecture for learning Operators. Unlike
    traditional machine learning methods DeepONet is designed to map
    entire functions to other functions. It can be trained both with
    Physics Informed or Supervised learning strategies.

    It is implemented as a :class:`MIONet` with exactly two networks:
    the branch net and the trunk net.

    .. seealso::

        **Original reference**: Lu, L., Jin, P., Pang, G. et al. *Learning
        nonlinear operators via DeepONet based on the universal approximation
        theorem of operators*. Nat Mach Intell 3, 218–229 (2021).
        DOI: `10.1038/s42256-021-00302-5
        <https://doi.org/10.1038/s42256-021-00302-5>`_

    """
    def __init__(self,
                 branch_net,
                 trunk_net,
                 input_indeces_branch_net,
                 input_indeces_trunk_net,
                 aggregator="*",
                 reduction="+",
                 scale=True,
                 translation=True):
        """
        :param torch.nn.Module branch_net: The neural network to use as branch
            model. It has to take as input a :class:`LabelTensor`
            or :class:`torch.Tensor`. The size of the last dimension of the
            output has to be the same as the one of the ``trunk_net``.
        :param torch.nn.Module trunk_net: The neural network to use as trunk
            model. It has to take as input a :class:`LabelTensor`
            or :class:`torch.Tensor`. The size of the last dimension of the
            output has to be the same as the one of the ``branch_net``.
        :param list(int) | list(str) input_indeces_branch_net: List of indeces
            to extract from the input variable in the forward pass for the
            branch net. If a list of ``int`` is passed, the corresponding
            columns of the inner most entries are extracted. If a list of
            ``str`` is passed the variables of the corresponding
            :class:`LabelTensor` are extracted.
        :param list(int) | list(str) input_indeces_trunk_net: List of indeces
            to extract from the input variable in the forward pass for the
            trunk net. If a list of ``int`` is passed, the corresponding
            columns of the inner most entries are extracted. If a list of
            ``str`` is passed the variables of the corresponding
            :class:`LabelTensor` are extracted.
        :param str | callable aggregator: Aggregator to be used to aggregate
            partial results from the branch and trunk nets. Partial results
            are aggregated component-wise. See
            :func:`pina.model.deeponet.MIONet._symbol_functions` for the
            available default aggregators.
        :param str | callable reduction: Reduction to be used to reduce
            the aggregated result of the branch and trunk nets to the desired
            output dimension. See
            :py:obj:`pina.model.deeponet.MIONet._symbol_functions` for the
            available default reductions.
        :param bool scale: Scaling the final output before returning the
            forward pass, default True.
        :param bool translation: Translating the final output before
            returning the forward pass, default True.
        :raises ValueError: If ``branch_net`` and ``trunk_net`` are the very
            same object.

        .. warning::
            In the forward pass we do not check if the input is instance of
            :class:`LabelTensor` or :class:`torch.Tensor`. A general rule is
            that for a :class:`LabelTensor` input both list of integers and
            list of strings can be passed for ``input_indeces_branch_net``
            and ``input_indeces_trunk_net``. Differently, for a
            :class:`torch.Tensor` only a list of integers can be passed for
            ``input_indeces_branch_net`` and ``input_indeces_trunk_net``.

        :Example:
            >>> branch_net = FeedForward(input_dimensions=1, output_dimensions=10)
            >>> trunk_net = FeedForward(input_dimensions=1, output_dimensions=10)
            >>> model = DeepONet(branch_net=branch_net,
            ...                  trunk_net=trunk_net,
            ...                  input_indeces_branch_net=['x'],
            ...                  input_indeces_trunk_net=['t'],
            ...                  reduction='+',
            ...                  aggregator='*')
            >>> model
            DeepONet(
              (models): ModuleList(
                (0): FeedForward(
                  (model): Sequential(
                    (0): Linear(in_features=1, out_features=20, bias=True)
                    (1): Tanh()
                    (2): Linear(in_features=20, out_features=20, bias=True)
                    (3): Tanh()
                    (4): Linear(in_features=20, out_features=10, bias=True)
                  )
                )
                (1): FeedForward(
                  (model): Sequential(
                    (0): Linear(in_features=1, out_features=20, bias=True)
                    (1): Tanh()
                    (2): Linear(in_features=20, out_features=20, bias=True)
                    (3): Tanh()
                    (4): Linear(in_features=20, out_features=10, bias=True)
                  )
                )
              )
            )
        """
        # The two networks are used as keys of a dict: passing the same
        # object twice would collapse them into a single entry, silently
        # discarding the branch indeces and leaving no second model for
        # ``trunk_net``. Fail early with an explicit message instead.
        if branch_net is trunk_net:
            raise ValueError('The branch and trunk networks must be two '
                             'distinct torch.nn.Module instances.')

        # Insertion order matters: the branch net must come first, since
        # the ``branch_net`` / ``trunk_net`` properties rely on positions
        # 0 and 1 respectively.
        networks = {branch_net: input_indeces_branch_net,
                    trunk_net: input_indeces_trunk_net}
        super().__init__(networks=networks,
                         aggregator=aggregator,
                         reduction=reduction,
                         scale=scale,
                         translation=translation)

    @property
    def branch_net(self):
        """
        The branch net for DeepONet.

        :return: The first model stored in ``models``.
        """
        return self.models[0]

    @property
    def trunk_net(self):
        """
        The trunk net for DeepONet.

        :return: The second model stored in ``models``.
        """
        return self.models[1]
|
||||
Reference in New Issue
Block a user