From 190ee0561dd65e29a7c89a524fbd450604906f00 Mon Sep 17 00:00:00 2001 From: Dario Coscia Date: Thu, 7 Sep 2023 10:47:26 +0200 Subject: [PATCH] minor fixes deeponet, remove LabelTensor import --- pina/model/deeponet.py | 35 ++++++++++++++++++++++++------- tests/test_model/test_deeponet.py | 12 +++++++++++ 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/pina/model/deeponet.py b/pina/model/deeponet.py index b1bf489..97f29f8 100644 --- a/pina/model/deeponet.py +++ b/pina/model/deeponet.py @@ -2,14 +2,18 @@ import torch import torch.nn as nn from ..utils import check_consistency, is_function -from functools import partial, reduce -from pina import LabelTensor +from functools import partial class DeepONet(torch.nn.Module): """ The PINA implementation of DeepONet network. + DeepONet is a general architecture for learning Operators. Unlike + traditional machine learning methods DeepONet is designed to map + entire functions to other functions. It can be trained both with + Physics Informed or Supervised learning strategies. + .. seealso:: **Original reference**: Lu, L., Jin, P., Pang, G. et al. *Learning @@ -55,6 +59,15 @@ class DeepONet(torch.nn.Module): dimension. See :py:obj:`pina.model.deeponet.DeepONet._symbol_functions` for the available default reductions. + .. warning:: + In the forward pass we do not check if the input is instance of + :class:`LabelTensor` or :class:`torch.Tensor`. A general rule is + that for a :class:`LabelTensor` input both list of integers and + list of strings can be passed for ``input_indeces_branch_net`` + and ``input_indeces_trunk_net``. Differently, for a :class:`torch.Tensor` + only a list of integers can be passed for ``input_indeces_branch_net`` + and ``input_indeces_trunk_net``. + :Example: >>> branch_net = FeedForward(input_dimensons=1, output_dimensions=10) >>> trunk_net = FeedForward(input_dimensons=1, output_dimensions=10) @@ -154,20 +167,26 @@ class DeepONet(torch.nn.Module): def _get_vars(self, x, indeces): if isinstance(indeces[0], str): - check_consistency(x, LabelTensor) - return x.extract(indeces) + try: + return x.extract(indeces) + except AttributeError: + raise RuntimeError('Not possible to extract input variables from tensor.' + ' Ensure that the passed tensor is a LabelTensor or' + ' pass list of integers to extract variables. For' + ' more information refer to warning in the documentation.') elif isinstance(indeces[0], int): return x[..., indeces] else: - raise RuntimeError('Not able to extract right indeces for tensor.') + raise RuntimeError('Not able to extract right indeces for tensor.' + ' For more information refer to warning in the documentation.') def forward(self, x): """ Defines the computation performed at every call. - :param LabelTensor x: the input tensor. - :return: the output computed by the model. - :rtype: LabelTensor + :param LabelTensor | torch.Tensor x: The input tensor for the forward call. + :return: The output computed by the DeepONet model. + :rtype: LabelTensor | torch.Tensor """ # forward pass diff --git a/tests/test_model/test_deeponet.py b/tests/test_model/test_deeponet.py index b4ed899..fe21f06 100644 --- a/tests/test_model/test_deeponet.py +++ b/tests/test_model/test_deeponet.py @@ -53,3 +53,15 @@ def test_forward_extract_int(): reduction='+', aggregator='*') model(data) + +def test_forward_extract_str_wrong(): + branch_net = FeedForward(input_dimensons=1, output_dimensions=10) + trunk_net = FeedForward(input_dimensons=2, output_dimensions=10) + model = DeepONet(branch_net=branch_net, + trunk_net=trunk_net, + input_indeces_branch_net=['a'], + input_indeces_trunk_net=['b', 'c'], + reduction='+', + aggregator='*') + with pytest.raises(RuntimeError): + model(data)