From ba7371f350684725ec7598d3781bd7a4650a54d6 Mon Sep 17 00:00:00 2001 From: Dario Coscia Date: Wed, 13 Sep 2023 09:46:52 +0200 Subject: [PATCH] MIONet addition and DeepONet modification. * Create MIONet base class * DeepONet inherits from MIONet --- pina/model/__init__.py | 3 +- pina/model/deeponet.py | 326 ++++++++++++++++++++++++-------- tests/test_model/test_mionet.py | 72 +++++++ 3 files changed, 322 insertions(+), 79 deletions(-) create mode 100644 tests/test_model/test_mionet.py diff --git a/pina/model/__init__.py b/pina/model/__init__.py index a1ea5b2..827cf0a 100644 --- a/pina/model/__init__.py +++ b/pina/model/__init__.py @@ -2,10 +2,11 @@ __all__ = [ 'FeedForward', 'MultiFeedForward', 'DeepONet', + 'MIONet', 'FNO', ] from .feed_forward import FeedForward from .multi_feed_forward import MultiFeedForward -from .deeponet import DeepONet +from .deeponet import DeepONet, MIONet from .fno import FNO diff --git a/pina/model/deeponet.py b/pina/model/deeponet.py index 26e5bdf..454724c 100644 --- a/pina/model/deeponet.py +++ b/pina/model/deeponet.py @@ -5,59 +5,53 @@ from ..utils import check_consistency, is_function from functools import partial -class DeepONet(torch.nn.Module): +class MIONet(torch.nn.Module): """ - The PINA implementation of DeepONet network. + The PINA implementation of MIONet network. - DeepONet is a general architecture for learning Operators. Unlike - traditional machine learning methods DeepONet is designed to map - entire functions to other functions. It can be trained both with - Physics Informed or Supervised learning strategies. + MIONet is a general architecture for learning Operators defined + on the tensor product of Banach spaces. Unlike traditional machine + learning methods MIONet is designed to map entire functions to other functions. + It can be trained both with Physics Informed or Supervised learning strategies. .. seealso:: - **Original reference**: Lu, L., Jin, P., Pang, G. et al. *Learning - nonlinear operators via DeepONet based on the universal approximation - theorem of operators*. Nat Mach Intell 3, 218–229 (2021). - DOI: `10.1038/s42256-021-00302-5 - `_ + **Original reference**: Jin, Pengzhan, Shuai Meng, and Lu Lu. + "MIONet: Learning multiple-input operators via tensor product." + SIAM Journal on Scientific Computing 44.6 (2022): A3490-A351 + DOI: `10.1137/22M1477751 + `_ """ def __init__(self, - branch_net, - trunk_net, - input_indeces_branch_net, - input_indeces_trunk_net, + networks, aggregator="*", - reduction="+"): + reduction="+", + scale=True, + translation=True): """ - :param torch.nn.Module branch_net: The neural network to use as branch - model. It has to take as input a :class:`LabelTensor` - or :class:`torch.Tensor`. The number of dimensions of the output has - to be the same of the ``trunk_net``. - :param torch.nn.Module trunk_net: The neural network to use as trunk - model. It has to take as input a :class:`LabelTensor` - or :class:`torch.Tensor`. The number of dimensions of the output - has to be the same of the ``branch_net``. - :param list(int) | list(str) input_indeces_branch_net: List of indeces - to extract from the input variable in the forward pass for the - branch net. If a list of ``int`` is passed, the corresponding columns - of the inner most entries are extracted. If a list of ``str`` is passed - the variables of the corresponding :class:`LabelTensor` are extracted. - :param list(int) | list(str) input_indeces_trunk_net: List of indeces - to extract from the input variable in the forward pass for the - trunk net. If a list of ``int`` is passed, the corresponding columns - of the inner most entries are extracted. If a list of ``str`` is passed - the variables of the corresponding :class:`LabelTensor` are extracted. + :param dict networks: The neural networks to use as + models. The ``dict`` takes as key a neural network, and + as value the list of indeces to extract from the input variable + in the forward pass of the neural network. If a list of ``int`` is passed, + the corresponding columns of the inner most entries are extracted. + If a list of ``str`` is passed the variables of the corresponding :class:`LabelTensor` + are extracted. The ``torch.nn.Module`` model has to take as input a + :class:`LabelTensor` or :class:`torch.Tensor`. Default implementation consist of different + branch nets and one trunk net. :param str | callable aggregator: Aggregator to be used to aggregate partial results from the modules in `nets`. Partial results are aggregated component-wise. See - :func:`pina.model.deeponet.DeepONet._symbol_functions` for the + :func:`pina.model.deeponet.MIONet._symbol_functions` for the available default aggregators. :param str | callable reduction: Reduction to be used to reduce the aggregated result of the modules in `nets` to the desired output - dimension. See :py:obj:`pina.model.deeponet.DeepONet._symbol_functions` for the available default - reductions. + dimension. See :py:obj:`pina.model.deeponet.MIONet._symbol_functions` + for the available default reductions. + :param bool | callable scale: Scaling the final output before returning the + forward pass, default True. + :param bool | callable translation: Translating the final output before + returning the forward pass, default True. .. warning:: In the forward pass we do not check if the input is instance of @@ -69,32 +63,44 @@ class DeepONet(torch.nn.Module): and ``input_indeces_trunk_net``. :Example: - >>> branch_net = FeedForward(input_dimensons=1, output_dimensions=10) + >>> branch_net1 = FeedForward(input_dimensons=1, output_dimensions=10) + >>> branch_net2 = FeedForward(input_dimensons=2, output_dimensions=10) >>> trunk_net = FeedForward(input_dimensons=1, output_dimensions=10) - >>> model = DeepONet(branch_net=branch_net, - ... trunk_net=trunk_net, - ... input_indeces_branch_net=['x'], - ... input_indeces_trunk_net=['t'], - ... reduction='+', - ... aggregator='*') + >>> networks = {branch_net1 : ['x'], + branch_net2 : ['x', 'y'], + ... trunk_net : ['z']} + >>> model = MIONet(networks=networks, + ... reduction='+', + ... aggregator='*') >>> model - DeepONet( - (trunk_net): FeedForward( + MIONet( + (models): ModuleList( + (0): FeedForward( (model): Sequential( - (0): Linear(in_features=1, out_features=20, bias=True) - (1): Tanh() - (2): Linear(in_features=20, out_features=20, bias=True) - (3): Tanh() - (4): Linear(in_features=20, out_features=10, bias=True) + (0): Linear(in_features=1, out_features=20, bias=True) + (1): Tanh() + (2): Linear(in_features=20, out_features=20, bias=True) + (3): Tanh() + (4): Linear(in_features=20, out_features=10, bias=True) ) - ) - (branch_net): FeedForward( + ) + (1): FeedForward( (model): Sequential( - (0): Linear(in_features=1, out_features=20, bias=True) - (1): Tanh() - (2): Linear(in_features=20, out_features=20, bias=True) - (3): Tanh() - (4): Linear(in_features=20, out_features=10, bias=True) + (0): Linear(in_features=2, out_features=20, bias=True) + (1): Tanh() + (2): Linear(in_features=20, out_features=20, bias=True) + (3): Tanh() + (4): Linear(in_features=20, out_features=10, bias=True) + ) + ) + (2): FeedForward( + (model): Sequential( + (0): Linear(in_features=1, out_features=20, bias=True) + (1): Tanh() + (2): Linear(in_features=20, out_features=20, bias=True) + (3): Tanh() + (4): Linear(in_features=20, out_features=10, bias=True) + ) ) ) ) @@ -102,33 +108,33 @@ class DeepONet(torch.nn.Module): super().__init__() # check type consistency - check_consistency(input_indeces_branch_net, (str, int)) - check_consistency(input_indeces_trunk_net, (str, int)) - check_consistency(branch_net, torch.nn.Module) - check_consistency(trunk_net, torch.nn.Module) + check_consistency(networks, dict) + check_consistency(scale, bool) + check_consistency(translation, bool) # check trunk branch nets consistency - input_trunk = torch.rand(10, len(input_indeces_trunk_net)) - input_branch = torch.rand(10, len(input_indeces_branch_net)) - trunk_out_dim = trunk_net(input_trunk).shape[-1] - branch_out_dim = branch_net(input_branch).shape[-1] - if trunk_out_dim != branch_out_dim: - raise ValueError('Branch and trunk networks have not the same ' + shapes = [] + for key, value in networks.items(): + check_consistency(value, (str, int)) + check_consistency(key, torch.nn.Module) + input_ = torch.rand(10, len(value)) + shapes.append(key(input_).shape[-1]) + + if not all(map(lambda x: x == shapes[0], shapes)): + raise ValueError('The passed networks have not the same ' 'output dimension.') # assign trunk and branch net with their input indeces - self.trunk_net = trunk_net - self._trunk_indeces = input_indeces_trunk_net - self.branch_net = branch_net - self._branch_indeces = input_indeces_branch_net + self.models = torch.nn.ModuleList(networks.keys()) + self._indeces = networks.values() # initializie aggregation self._init_aggregator(aggregator=aggregator) self._init_reduction(reduction=reduction) # scale and translation - self._scale = torch.nn.Parameter(torch.tensor([1.0])) - self._trasl = torch.nn.Parameter(torch.tensor([0.0])) + self._scale = torch.nn.Parameter(torch.tensor([1.0])) if scale else torch.tensor([1.0]) + self._trasl = torch.nn.Parameter(torch.tensor([1.0])) if translation else torch.tensor([1.0]) @staticmethod def _symbol_functions(**kwargs): @@ -192,11 +198,10 @@ class DeepONet(torch.nn.Module): """ # forward pass - branch_output = self.branch_net(self._get_vars(x, self._branch_indeces)) - trunk_output = self.trunk_net(self._get_vars(x, self._trunk_indeces)) + output_ = [model(self._get_vars(x, indeces)) for model, indeces in zip(self.models, self._indeces)] # aggregation - aggregated = self._aggregator(torch.dstack((branch_output, trunk_output))) + aggregated = self._aggregator(torch.dstack(output_)) # reduce output_ = self._reduction(aggregated).reshape(-1, 1) @@ -205,4 +210,169 @@ class DeepONet(torch.nn.Module): output_ *= self._scale output_ += self._trasl - return output_ \ No newline at end of file + return output_ + + @property + def aggregator(self): + """ + The aggregator function. + """ + return self._aggregator + + @property + def reduction(self): + """ + The translation factor. + """ + return self._reduction + + @property + def scale(self): + """ + The scale factor. + """ + return self._scale + + @property + def translation(self): + """ + The translation factor for MIONet. + """ + return self._trasl + + @property + def indeces_variables_extracted(self): + """ + The input indeces for each model in form of list. + """ + return self._indeces + + @property + def model(self): + """ + The models in form of list. + """ + return self._indeces + + +class DeepONet(MIONet): + """ + The PINA implementation of DeepONet network. + + DeepONet is a general architecture for learning Operators. Unlike + traditional machine learning methods DeepONet is designed to map + entire functions to other functions. It can be trained both with + Physics Informed or Supervised learning strategies. + + .. seealso:: + + **Original reference**: Lu, L., Jin, P., Pang, G. et al. *Learning + nonlinear operators via DeepONet based on the universal approximation + theorem of operators*. Nat Mach Intell 3, 218–229 (2021). + DOI: `10.1038/s42256-021-00302-5 + `_ + + """ + def __init__(self, + branch_net, + trunk_net, + input_indeces_branch_net, + input_indeces_trunk_net, + aggregator="*", + reduction="+", + scale=True, + translation=True): + """ + :param torch.nn.Module branch_net: The neural network to use as branch + model. It has to take as input a :class:`LabelTensor` + or :class:`torch.Tensor`. The number of dimensions of the output has + to be the same of the ``trunk_net``. + :param torch.nn.Module trunk_net: The neural network to use as trunk + model. It has to take as input a :class:`LabelTensor` + or :class:`torch.Tensor`. The number of dimensions of the output + has to be the same of the ``branch_net``. + :param list(int) | list(str) input_indeces_branch_net: List of indeces + to extract from the input variable in the forward pass for the + branch net. If a list of ``int`` is passed, the corresponding columns + of the inner most entries are extracted. If a list of ``str`` is passed + the variables of the corresponding :class:`LabelTensor` are extracted. + :param list(int) | list(str) input_indeces_trunk_net: List of indeces + to extract from the input variable in the forward pass for the + trunk net. If a list of ``int`` is passed, the corresponding columns + of the inner most entries are extracted. If a list of ``str`` is passed + the variables of the corresponding :class:`LabelTensor` are extracted. + :param str | callable aggregator: Aggregator to be used to aggregate + partial results from the modules in `nets`. Partial results are + aggregated component-wise. See + :func:`pina.model.deeponet.MIONet._symbol_functions` for the + available default aggregators. + :param str | callable reduction: Reduction to be used to reduce + the aggregated result of the modules in `nets` to the desired output + dimension. See :py:obj:`pina.model.deeponet.MIONet._symbol_functions` for the available default + reductions. + :param bool | callable scale: Scaling the final output before returning the + forward pass, default True. + :param bool | callable translation: Translating the final output before + returning the forward pass, default True. + + .. warning:: + In the forward pass we do not check if the input is instance of + :class:`LabelTensor` or :class:`torch.Tensor`. A general rule is + that for a :class:`LabelTensor` input both list of integers and + list of strings can be passed for ``input_indeces_branch_net`` + and ``input_indeces_trunk_net``. Differently, for a :class:`torch.Tensor` + only a list of integers can be passed for ``input_indeces_branch_net`` + and ``input_indeces_trunk_net``. + + :Example: + >>> branch_net = FeedForward(input_dimensons=1, output_dimensions=10) + >>> trunk_net = FeedForward(input_dimensons=1, output_dimensions=10) + >>> model = DeepONet(branch_net=branch_net, + ... trunk_net=trunk_net, + ... input_indeces_branch_net=['x'], + ... input_indeces_trunk_net=['t'], + ... reduction='+', + ... aggregator='*') + >>> model + DeepONet( + (trunk_net): FeedForward( + (model): Sequential( + (0): Linear(in_features=1, out_features=20, bias=True) + (1): Tanh() + (2): Linear(in_features=20, out_features=20, bias=True) + (3): Tanh() + (4): Linear(in_features=20, out_features=10, bias=True) + ) + ) + (branch_net): FeedForward( + (model): Sequential( + (0): Linear(in_features=1, out_features=20, bias=True) + (1): Tanh() + (2): Linear(in_features=20, out_features=20, bias=True) + (3): Tanh() + (4): Linear(in_features=20, out_features=10, bias=True) + ) + ) + ) + """ + networks = {branch_net : input_indeces_branch_net, + trunk_net : input_indeces_trunk_net} + super().__init__(networks=networks, + aggregator=aggregator, + reduction=reduction, + scale=scale, + translation=translation) + + @property + def branch_net(self): + """ + The branch net for DeepONet. + """ + return self.models[0] + + @property + def trunk_net(self): + """ + The trunk net for DeepONet. + """ + return self.models[1] \ No newline at end of file diff --git a/tests/test_model/test_mionet.py b/tests/test_model/test_mionet.py new file mode 100644 index 0000000..5150485 --- /dev/null +++ b/tests/test_model/test_mionet.py @@ -0,0 +1,72 @@ +import pytest +import torch + +from pina import LabelTensor +from pina.model import MIONet +from pina.model import FeedForward + +data = torch.rand((20, 3)) +input_vars = ['a', 'b', 'c'] +input_ = LabelTensor(data, input_vars) + + +def test_constructor(): + branch_net1 = FeedForward(input_dimensons=1, output_dimensions=10) + branch_net2 = FeedForward(input_dimensons=2, output_dimensions=10) + trunk_net = FeedForward(input_dimensons=1, output_dimensions=10) + networks = {branch_net1 : ['x'], + branch_net2 : ['x', 'y'], + trunk_net : ['z']} + MIONet(networks=networks, + reduction='+', + aggregator='*') + + +def test_constructor_fails_when_invalid_inner_layer_size(): + branch_net1 = FeedForward(input_dimensons=1, output_dimensions=10) + branch_net2 = FeedForward(input_dimensons=2, output_dimensions=10) + trunk_net = FeedForward(input_dimensons=1, output_dimensions=12) + networks = {branch_net1 : ['x'], + branch_net2 : ['x', 'y'], + trunk_net : ['z']} + with pytest.raises(ValueError): + MIONet(networks=networks, + reduction='+', + aggregator='*') + +def test_forward_extract_str(): + branch_net1 = FeedForward(input_dimensons=1, output_dimensions=10) + branch_net2 = FeedForward(input_dimensons=1, output_dimensions=10) + trunk_net = FeedForward(input_dimensons=1, output_dimensions=10) + networks = {branch_net1 : ['a'], + branch_net2 : ['b'], + trunk_net : ['c']} + model = MIONet(networks=networks, + reduction='+', + aggregator='*') + model(input_) + +def test_forward_extract_int(): + branch_net1 = FeedForward(input_dimensons=1, output_dimensions=10) + branch_net2 = FeedForward(input_dimensons=1, output_dimensions=10) + trunk_net = FeedForward(input_dimensons=1, output_dimensions=10) + networks = {branch_net1 : [0], + branch_net2 : [1], + trunk_net : [2]} + model = MIONet(networks=networks, + reduction='+', + aggregator='*') + model(data) + +def test_forward_extract_str_wrong(): + branch_net1 = FeedForward(input_dimensons=1, output_dimensions=10) + branch_net2 = FeedForward(input_dimensons=1, output_dimensions=10) + trunk_net = FeedForward(input_dimensons=1, output_dimensions=10) + networks = {branch_net1 : ['a'], + branch_net2 : ['b'], + trunk_net : ['c']} + model = MIONet(networks=networks, + reduction='+', + aggregator='*') + with pytest.raises(RuntimeError): + model(data)