MIONet addition and DeepONet modification.
* Create MIONet base class * DeepONet inherits from MIONet
This commit is contained in:
committed by
Nicola Demo
parent
603f56d264
commit
ba7371f350
@@ -2,10 +2,11 @@ __all__ = [
|
|||||||
'FeedForward',
|
'FeedForward',
|
||||||
'MultiFeedForward',
|
'MultiFeedForward',
|
||||||
'DeepONet',
|
'DeepONet',
|
||||||
|
'MIONet',
|
||||||
'FNO',
|
'FNO',
|
||||||
]
|
]
|
||||||
|
|
||||||
from .feed_forward import FeedForward
|
from .feed_forward import FeedForward
|
||||||
from .multi_feed_forward import MultiFeedForward
|
from .multi_feed_forward import MultiFeedForward
|
||||||
from .deeponet import DeepONet
|
from .deeponet import DeepONet, MIONet
|
||||||
from .fno import FNO
|
from .fno import FNO
|
||||||
|
|||||||
@@ -5,59 +5,53 @@ from ..utils import check_consistency, is_function
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
|
|
||||||
class DeepONet(torch.nn.Module):
|
class MIONet(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
The PINA implementation of DeepONet network.
|
The PINA implementation of MIONet network.
|
||||||
|
|
||||||
DeepONet is a general architecture for learning Operators. Unlike
|
MIONet is a general architecture for learning Operators defined
|
||||||
traditional machine learning methods DeepONet is designed to map
|
on the tensor product of Banach spaces. Unlike traditional machine
|
||||||
entire functions to other functions. It can be trained both with
|
learning methods MIONet is designed to map entire functions to other functions.
|
||||||
Physics Informed or Supervised learning strategies.
|
It can be trained both with Physics Informed or Supervised learning strategies.
|
||||||
|
|
||||||
.. seealso::
|
.. seealso::
|
||||||
|
|
||||||
**Original reference**: Lu, L., Jin, P., Pang, G. et al. *Learning
|
**Original reference**: Jin, Pengzhan, Shuai Meng, and Lu Lu.
|
||||||
nonlinear operators via DeepONet based on the universal approximation
|
"MIONet: Learning multiple-input operators via tensor product."
|
||||||
theorem of operators*. Nat Mach Intell 3, 218–229 (2021).
|
SIAM Journal on Scientific Computing 44.6 (2022): A3490-A351
|
||||||
DOI: `10.1038/s42256-021-00302-5
|
DOI: `10.1137/22M1477751
|
||||||
<https://doi.org/10.1038/s42256-021-00302-5>`_
|
<https://doi.org/10.1137/22M1477751>`_
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
branch_net,
|
networks,
|
||||||
trunk_net,
|
|
||||||
input_indeces_branch_net,
|
|
||||||
input_indeces_trunk_net,
|
|
||||||
aggregator="*",
|
aggregator="*",
|
||||||
reduction="+"):
|
reduction="+",
|
||||||
|
scale=True,
|
||||||
|
translation=True):
|
||||||
"""
|
"""
|
||||||
:param torch.nn.Module branch_net: The neural network to use as branch
|
:param dict networks: The neural networks to use as
|
||||||
model. It has to take as input a :class:`LabelTensor`
|
models. The ``dict`` takes as key a neural network, and
|
||||||
or :class:`torch.Tensor`. The number of dimensions of the output has
|
as value the list of indeces to extract from the input variable
|
||||||
to be the same of the ``trunk_net``.
|
in the forward pass of the neural network. If a list of ``int`` is passed,
|
||||||
:param torch.nn.Module trunk_net: The neural network to use as trunk
|
the corresponding columns of the inner most entries are extracted.
|
||||||
model. It has to take as input a :class:`LabelTensor`
|
If a list of ``str`` is passed the variables of the corresponding :class:`LabelTensor`
|
||||||
or :class:`torch.Tensor`. The number of dimensions of the output
|
are extracted. The ``torch.nn.Module`` model has to take as input a
|
||||||
has to be the same of the ``branch_net``.
|
:class:`LabelTensor` or :class:`torch.Tensor`. Default implementation consist of different
|
||||||
:param list(int) | list(str) input_indeces_branch_net: List of indeces
|
branch nets and one trunk net.
|
||||||
to extract from the input variable in the forward pass for the
|
|
||||||
branch net. If a list of ``int`` is passed, the corresponding columns
|
|
||||||
of the inner most entries are extracted. If a list of ``str`` is passed
|
|
||||||
the variables of the corresponding :class:`LabelTensor` are extracted.
|
|
||||||
:param list(int) | list(str) input_indeces_trunk_net: List of indeces
|
|
||||||
to extract from the input variable in the forward pass for the
|
|
||||||
trunk net. If a list of ``int`` is passed, the corresponding columns
|
|
||||||
of the inner most entries are extracted. If a list of ``str`` is passed
|
|
||||||
the variables of the corresponding :class:`LabelTensor` are extracted.
|
|
||||||
:param str | callable aggregator: Aggregator to be used to aggregate
|
:param str | callable aggregator: Aggregator to be used to aggregate
|
||||||
partial results from the modules in `nets`. Partial results are
|
partial results from the modules in `nets`. Partial results are
|
||||||
aggregated component-wise. See
|
aggregated component-wise. See
|
||||||
:func:`pina.model.deeponet.DeepONet._symbol_functions` for the
|
:func:`pina.model.deeponet.MIONet._symbol_functions` for the
|
||||||
available default aggregators.
|
available default aggregators.
|
||||||
:param str | callable reduction: Reduction to be used to reduce
|
:param str | callable reduction: Reduction to be used to reduce
|
||||||
the aggregated result of the modules in `nets` to the desired output
|
the aggregated result of the modules in `nets` to the desired output
|
||||||
dimension. See :py:obj:`pina.model.deeponet.DeepONet._symbol_functions` for the available default
|
dimension. See :py:obj:`pina.model.deeponet.MIONet._symbol_functions`
|
||||||
reductions.
|
for the available default reductions.
|
||||||
|
:param bool | callable scale: Scaling the final output before returning the
|
||||||
|
forward pass, default True.
|
||||||
|
:param bool | callable translation: Translating the final output before
|
||||||
|
returning the forward pass, default True.
|
||||||
|
|
||||||
.. warning::
|
.. warning::
|
||||||
In the forward pass we do not check if the input is instance of
|
In the forward pass we do not check if the input is instance of
|
||||||
@@ -69,32 +63,44 @@ class DeepONet(torch.nn.Module):
|
|||||||
and ``input_indeces_trunk_net``.
|
and ``input_indeces_trunk_net``.
|
||||||
|
|
||||||
:Example:
|
:Example:
|
||||||
>>> branch_net = FeedForward(input_dimensons=1, output_dimensions=10)
|
>>> branch_net1 = FeedForward(input_dimensons=1, output_dimensions=10)
|
||||||
|
>>> branch_net2 = FeedForward(input_dimensons=2, output_dimensions=10)
|
||||||
>>> trunk_net = FeedForward(input_dimensons=1, output_dimensions=10)
|
>>> trunk_net = FeedForward(input_dimensons=1, output_dimensions=10)
|
||||||
>>> model = DeepONet(branch_net=branch_net,
|
>>> networks = {branch_net1 : ['x'],
|
||||||
... trunk_net=trunk_net,
|
branch_net2 : ['x', 'y'],
|
||||||
... input_indeces_branch_net=['x'],
|
... trunk_net : ['z']}
|
||||||
... input_indeces_trunk_net=['t'],
|
>>> model = MIONet(networks=networks,
|
||||||
... reduction='+',
|
... reduction='+',
|
||||||
... aggregator='*')
|
... aggregator='*')
|
||||||
>>> model
|
>>> model
|
||||||
DeepONet(
|
MIONet(
|
||||||
(trunk_net): FeedForward(
|
(models): ModuleList(
|
||||||
|
(0): FeedForward(
|
||||||
(model): Sequential(
|
(model): Sequential(
|
||||||
(0): Linear(in_features=1, out_features=20, bias=True)
|
(0): Linear(in_features=1, out_features=20, bias=True)
|
||||||
(1): Tanh()
|
(1): Tanh()
|
||||||
(2): Linear(in_features=20, out_features=20, bias=True)
|
(2): Linear(in_features=20, out_features=20, bias=True)
|
||||||
(3): Tanh()
|
(3): Tanh()
|
||||||
(4): Linear(in_features=20, out_features=10, bias=True)
|
(4): Linear(in_features=20, out_features=10, bias=True)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
(branch_net): FeedForward(
|
(1): FeedForward(
|
||||||
(model): Sequential(
|
(model): Sequential(
|
||||||
(0): Linear(in_features=1, out_features=20, bias=True)
|
(0): Linear(in_features=2, out_features=20, bias=True)
|
||||||
(1): Tanh()
|
(1): Tanh()
|
||||||
(2): Linear(in_features=20, out_features=20, bias=True)
|
(2): Linear(in_features=20, out_features=20, bias=True)
|
||||||
(3): Tanh()
|
(3): Tanh()
|
||||||
(4): Linear(in_features=20, out_features=10, bias=True)
|
(4): Linear(in_features=20, out_features=10, bias=True)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
(2): FeedForward(
|
||||||
|
(model): Sequential(
|
||||||
|
(0): Linear(in_features=1, out_features=20, bias=True)
|
||||||
|
(1): Tanh()
|
||||||
|
(2): Linear(in_features=20, out_features=20, bias=True)
|
||||||
|
(3): Tanh()
|
||||||
|
(4): Linear(in_features=20, out_features=10, bias=True)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -102,33 +108,33 @@ class DeepONet(torch.nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# check type consistency
|
# check type consistency
|
||||||
check_consistency(input_indeces_branch_net, (str, int))
|
check_consistency(networks, dict)
|
||||||
check_consistency(input_indeces_trunk_net, (str, int))
|
check_consistency(scale, bool)
|
||||||
check_consistency(branch_net, torch.nn.Module)
|
check_consistency(translation, bool)
|
||||||
check_consistency(trunk_net, torch.nn.Module)
|
|
||||||
|
|
||||||
# check trunk branch nets consistency
|
# check trunk branch nets consistency
|
||||||
input_trunk = torch.rand(10, len(input_indeces_trunk_net))
|
shapes = []
|
||||||
input_branch = torch.rand(10, len(input_indeces_branch_net))
|
for key, value in networks.items():
|
||||||
trunk_out_dim = trunk_net(input_trunk).shape[-1]
|
check_consistency(value, (str, int))
|
||||||
branch_out_dim = branch_net(input_branch).shape[-1]
|
check_consistency(key, torch.nn.Module)
|
||||||
if trunk_out_dim != branch_out_dim:
|
input_ = torch.rand(10, len(value))
|
||||||
raise ValueError('Branch and trunk networks have not the same '
|
shapes.append(key(input_).shape[-1])
|
||||||
|
|
||||||
|
if not all(map(lambda x: x == shapes[0], shapes)):
|
||||||
|
raise ValueError('The passed networks have not the same '
|
||||||
'output dimension.')
|
'output dimension.')
|
||||||
|
|
||||||
# assign trunk and branch net with their input indeces
|
# assign trunk and branch net with their input indeces
|
||||||
self.trunk_net = trunk_net
|
self.models = torch.nn.ModuleList(networks.keys())
|
||||||
self._trunk_indeces = input_indeces_trunk_net
|
self._indeces = networks.values()
|
||||||
self.branch_net = branch_net
|
|
||||||
self._branch_indeces = input_indeces_branch_net
|
|
||||||
|
|
||||||
# initializie aggregation
|
# initializie aggregation
|
||||||
self._init_aggregator(aggregator=aggregator)
|
self._init_aggregator(aggregator=aggregator)
|
||||||
self._init_reduction(reduction=reduction)
|
self._init_reduction(reduction=reduction)
|
||||||
|
|
||||||
# scale and translation
|
# scale and translation
|
||||||
self._scale = torch.nn.Parameter(torch.tensor([1.0]))
|
self._scale = torch.nn.Parameter(torch.tensor([1.0])) if scale else torch.tensor([1.0])
|
||||||
self._trasl = torch.nn.Parameter(torch.tensor([0.0]))
|
self._trasl = torch.nn.Parameter(torch.tensor([1.0])) if translation else torch.tensor([1.0])
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _symbol_functions(**kwargs):
|
def _symbol_functions(**kwargs):
|
||||||
@@ -192,11 +198,10 @@ class DeepONet(torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
branch_output = self.branch_net(self._get_vars(x, self._branch_indeces))
|
output_ = [model(self._get_vars(x, indeces)) for model, indeces in zip(self.models, self._indeces)]
|
||||||
trunk_output = self.trunk_net(self._get_vars(x, self._trunk_indeces))
|
|
||||||
|
|
||||||
# aggregation
|
# aggregation
|
||||||
aggregated = self._aggregator(torch.dstack((branch_output, trunk_output)))
|
aggregated = self._aggregator(torch.dstack(output_))
|
||||||
|
|
||||||
# reduce
|
# reduce
|
||||||
output_ = self._reduction(aggregated).reshape(-1, 1)
|
output_ = self._reduction(aggregated).reshape(-1, 1)
|
||||||
@@ -205,4 +210,169 @@ class DeepONet(torch.nn.Module):
|
|||||||
output_ *= self._scale
|
output_ *= self._scale
|
||||||
output_ += self._trasl
|
output_ += self._trasl
|
||||||
|
|
||||||
return output_
|
return output_
|
||||||
|
|
||||||
|
@property
|
||||||
|
def aggregator(self):
|
||||||
|
"""
|
||||||
|
The aggregator function.
|
||||||
|
"""
|
||||||
|
return self._aggregator
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reduction(self):
|
||||||
|
"""
|
||||||
|
The translation factor.
|
||||||
|
"""
|
||||||
|
return self._reduction
|
||||||
|
|
||||||
|
@property
|
||||||
|
def scale(self):
|
||||||
|
"""
|
||||||
|
The scale factor.
|
||||||
|
"""
|
||||||
|
return self._scale
|
||||||
|
|
||||||
|
@property
|
||||||
|
def translation(self):
|
||||||
|
"""
|
||||||
|
The translation factor for MIONet.
|
||||||
|
"""
|
||||||
|
return self._trasl
|
||||||
|
|
||||||
|
@property
|
||||||
|
def indeces_variables_extracted(self):
|
||||||
|
"""
|
||||||
|
The input indeces for each model in form of list.
|
||||||
|
"""
|
||||||
|
return self._indeces
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model(self):
|
||||||
|
"""
|
||||||
|
The models in form of list.
|
||||||
|
"""
|
||||||
|
return self._indeces
|
||||||
|
|
||||||
|
|
||||||
|
class DeepONet(MIONet):
|
||||||
|
"""
|
||||||
|
The PINA implementation of DeepONet network.
|
||||||
|
|
||||||
|
DeepONet is a general architecture for learning Operators. Unlike
|
||||||
|
traditional machine learning methods DeepONet is designed to map
|
||||||
|
entire functions to other functions. It can be trained both with
|
||||||
|
Physics Informed or Supervised learning strategies.
|
||||||
|
|
||||||
|
.. seealso::
|
||||||
|
|
||||||
|
**Original reference**: Lu, L., Jin, P., Pang, G. et al. *Learning
|
||||||
|
nonlinear operators via DeepONet based on the universal approximation
|
||||||
|
theorem of operators*. Nat Mach Intell 3, 218–229 (2021).
|
||||||
|
DOI: `10.1038/s42256-021-00302-5
|
||||||
|
<https://doi.org/10.1038/s42256-021-00302-5>`_
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
branch_net,
|
||||||
|
trunk_net,
|
||||||
|
input_indeces_branch_net,
|
||||||
|
input_indeces_trunk_net,
|
||||||
|
aggregator="*",
|
||||||
|
reduction="+",
|
||||||
|
scale=True,
|
||||||
|
translation=True):
|
||||||
|
"""
|
||||||
|
:param torch.nn.Module branch_net: The neural network to use as branch
|
||||||
|
model. It has to take as input a :class:`LabelTensor`
|
||||||
|
or :class:`torch.Tensor`. The number of dimensions of the output has
|
||||||
|
to be the same of the ``trunk_net``.
|
||||||
|
:param torch.nn.Module trunk_net: The neural network to use as trunk
|
||||||
|
model. It has to take as input a :class:`LabelTensor`
|
||||||
|
or :class:`torch.Tensor`. The number of dimensions of the output
|
||||||
|
has to be the same of the ``branch_net``.
|
||||||
|
:param list(int) | list(str) input_indeces_branch_net: List of indeces
|
||||||
|
to extract from the input variable in the forward pass for the
|
||||||
|
branch net. If a list of ``int`` is passed, the corresponding columns
|
||||||
|
of the inner most entries are extracted. If a list of ``str`` is passed
|
||||||
|
the variables of the corresponding :class:`LabelTensor` are extracted.
|
||||||
|
:param list(int) | list(str) input_indeces_trunk_net: List of indeces
|
||||||
|
to extract from the input variable in the forward pass for the
|
||||||
|
trunk net. If a list of ``int`` is passed, the corresponding columns
|
||||||
|
of the inner most entries are extracted. If a list of ``str`` is passed
|
||||||
|
the variables of the corresponding :class:`LabelTensor` are extracted.
|
||||||
|
:param str | callable aggregator: Aggregator to be used to aggregate
|
||||||
|
partial results from the modules in `nets`. Partial results are
|
||||||
|
aggregated component-wise. See
|
||||||
|
:func:`pina.model.deeponet.MIONet._symbol_functions` for the
|
||||||
|
available default aggregators.
|
||||||
|
:param str | callable reduction: Reduction to be used to reduce
|
||||||
|
the aggregated result of the modules in `nets` to the desired output
|
||||||
|
dimension. See :py:obj:`pina.model.deeponet.MIONet._symbol_functions` for the available default
|
||||||
|
reductions.
|
||||||
|
:param bool | callable scale: Scaling the final output before returning the
|
||||||
|
forward pass, default True.
|
||||||
|
:param bool | callable translation: Translating the final output before
|
||||||
|
returning the forward pass, default True.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
In the forward pass we do not check if the input is instance of
|
||||||
|
:class:`LabelTensor` or :class:`torch.Tensor`. A general rule is
|
||||||
|
that for a :class:`LabelTensor` input both list of integers and
|
||||||
|
list of strings can be passed for ``input_indeces_branch_net``
|
||||||
|
and ``input_indeces_trunk_net``. Differently, for a :class:`torch.Tensor`
|
||||||
|
only a list of integers can be passed for ``input_indeces_branch_net``
|
||||||
|
and ``input_indeces_trunk_net``.
|
||||||
|
|
||||||
|
:Example:
|
||||||
|
>>> branch_net = FeedForward(input_dimensons=1, output_dimensions=10)
|
||||||
|
>>> trunk_net = FeedForward(input_dimensons=1, output_dimensions=10)
|
||||||
|
>>> model = DeepONet(branch_net=branch_net,
|
||||||
|
... trunk_net=trunk_net,
|
||||||
|
... input_indeces_branch_net=['x'],
|
||||||
|
... input_indeces_trunk_net=['t'],
|
||||||
|
... reduction='+',
|
||||||
|
... aggregator='*')
|
||||||
|
>>> model
|
||||||
|
DeepONet(
|
||||||
|
(trunk_net): FeedForward(
|
||||||
|
(model): Sequential(
|
||||||
|
(0): Linear(in_features=1, out_features=20, bias=True)
|
||||||
|
(1): Tanh()
|
||||||
|
(2): Linear(in_features=20, out_features=20, bias=True)
|
||||||
|
(3): Tanh()
|
||||||
|
(4): Linear(in_features=20, out_features=10, bias=True)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
(branch_net): FeedForward(
|
||||||
|
(model): Sequential(
|
||||||
|
(0): Linear(in_features=1, out_features=20, bias=True)
|
||||||
|
(1): Tanh()
|
||||||
|
(2): Linear(in_features=20, out_features=20, bias=True)
|
||||||
|
(3): Tanh()
|
||||||
|
(4): Linear(in_features=20, out_features=10, bias=True)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
networks = {branch_net : input_indeces_branch_net,
|
||||||
|
trunk_net : input_indeces_trunk_net}
|
||||||
|
super().__init__(networks=networks,
|
||||||
|
aggregator=aggregator,
|
||||||
|
reduction=reduction,
|
||||||
|
scale=scale,
|
||||||
|
translation=translation)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def branch_net(self):
|
||||||
|
"""
|
||||||
|
The branch net for DeepONet.
|
||||||
|
"""
|
||||||
|
return self.models[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def trunk_net(self):
|
||||||
|
"""
|
||||||
|
The trunk net for DeepONet.
|
||||||
|
"""
|
||||||
|
return self.models[1]
|
||||||
72
tests/test_model/test_mionet.py
Normal file
72
tests/test_model/test_mionet.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from pina import LabelTensor
|
||||||
|
from pina.model import MIONet
|
||||||
|
from pina.model import FeedForward
|
||||||
|
|
||||||
|
data = torch.rand((20, 3))
|
||||||
|
input_vars = ['a', 'b', 'c']
|
||||||
|
input_ = LabelTensor(data, input_vars)
|
||||||
|
|
||||||
|
|
||||||
|
def test_constructor():
|
||||||
|
branch_net1 = FeedForward(input_dimensons=1, output_dimensions=10)
|
||||||
|
branch_net2 = FeedForward(input_dimensons=2, output_dimensions=10)
|
||||||
|
trunk_net = FeedForward(input_dimensons=1, output_dimensions=10)
|
||||||
|
networks = {branch_net1 : ['x'],
|
||||||
|
branch_net2 : ['x', 'y'],
|
||||||
|
trunk_net : ['z']}
|
||||||
|
MIONet(networks=networks,
|
||||||
|
reduction='+',
|
||||||
|
aggregator='*')
|
||||||
|
|
||||||
|
|
||||||
|
def test_constructor_fails_when_invalid_inner_layer_size():
|
||||||
|
branch_net1 = FeedForward(input_dimensons=1, output_dimensions=10)
|
||||||
|
branch_net2 = FeedForward(input_dimensons=2, output_dimensions=10)
|
||||||
|
trunk_net = FeedForward(input_dimensons=1, output_dimensions=12)
|
||||||
|
networks = {branch_net1 : ['x'],
|
||||||
|
branch_net2 : ['x', 'y'],
|
||||||
|
trunk_net : ['z']}
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
MIONet(networks=networks,
|
||||||
|
reduction='+',
|
||||||
|
aggregator='*')
|
||||||
|
|
||||||
|
def test_forward_extract_str():
|
||||||
|
branch_net1 = FeedForward(input_dimensons=1, output_dimensions=10)
|
||||||
|
branch_net2 = FeedForward(input_dimensons=1, output_dimensions=10)
|
||||||
|
trunk_net = FeedForward(input_dimensons=1, output_dimensions=10)
|
||||||
|
networks = {branch_net1 : ['a'],
|
||||||
|
branch_net2 : ['b'],
|
||||||
|
trunk_net : ['c']}
|
||||||
|
model = MIONet(networks=networks,
|
||||||
|
reduction='+',
|
||||||
|
aggregator='*')
|
||||||
|
model(input_)
|
||||||
|
|
||||||
|
def test_forward_extract_int():
|
||||||
|
branch_net1 = FeedForward(input_dimensons=1, output_dimensions=10)
|
||||||
|
branch_net2 = FeedForward(input_dimensons=1, output_dimensions=10)
|
||||||
|
trunk_net = FeedForward(input_dimensons=1, output_dimensions=10)
|
||||||
|
networks = {branch_net1 : [0],
|
||||||
|
branch_net2 : [1],
|
||||||
|
trunk_net : [2]}
|
||||||
|
model = MIONet(networks=networks,
|
||||||
|
reduction='+',
|
||||||
|
aggregator='*')
|
||||||
|
model(data)
|
||||||
|
|
||||||
|
def test_forward_extract_str_wrong():
|
||||||
|
branch_net1 = FeedForward(input_dimensons=1, output_dimensions=10)
|
||||||
|
branch_net2 = FeedForward(input_dimensons=1, output_dimensions=10)
|
||||||
|
trunk_net = FeedForward(input_dimensons=1, output_dimensions=10)
|
||||||
|
networks = {branch_net1 : ['a'],
|
||||||
|
branch_net2 : ['b'],
|
||||||
|
trunk_net : ['c']}
|
||||||
|
model = MIONet(networks=networks,
|
||||||
|
reduction='+',
|
||||||
|
aggregator='*')
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
model(data)
|
||||||
Reference in New Issue
Block a user