minor fixes deeponet, remove LabelTensor import

This commit is contained in:
Dario Coscia
2023-09-07 10:47:26 +02:00
committed by Nicola Demo
parent 3a7fc63760
commit 190ee0561d
2 changed files with 39 additions and 8 deletions

View File

@@ -2,14 +2,18 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from ..utils import check_consistency, is_function from ..utils import check_consistency, is_function
from functools import partial, reduce from functools import partial
from pina import LabelTensor
class DeepONet(torch.nn.Module): class DeepONet(torch.nn.Module):
""" """
The PINA implementation of DeepONet network. The PINA implementation of DeepONet network.
DeepONet is a general architecture for learning Operators. Unlike
traditional machine learning methods DeepONet is designed to map
entire functions to other functions. It can be trained both with
Physics Informed or Supervised learning strategies.
.. seealso:: .. seealso::
**Original reference**: Lu, L., Jin, P., Pang, G. et al. *Learning **Original reference**: Lu, L., Jin, P., Pang, G. et al. *Learning
@@ -55,6 +59,15 @@ class DeepONet(torch.nn.Module):
dimension. See :py:obj:`pina.model.deeponet.DeepONet._symbol_functions` for the available default dimension. See :py:obj:`pina.model.deeponet.DeepONet._symbol_functions` for the available default
reductions. reductions.
.. warning::
In the forward pass we do not check if the input is instance of
:class:`LabelTensor` or :class:`torch.Tensor`. A general rule is
that for a :class:`LabelTensor` input both list of integers and
list of strings can be passed for ``input_indeces_branch_net``
and ``input_indeces_trunk_net``. Differently, for a :class:`torch.Tensor`
only a list of integers can be passed for ``input_indeces_branch_net``
and ``input_indeces_trunk_net``.
:Example: :Example:
>>> branch_net = FeedForward(input_dimensons=1, output_dimensions=10) >>> branch_net = FeedForward(input_dimensons=1, output_dimensions=10)
>>> trunk_net = FeedForward(input_dimensons=1, output_dimensions=10) >>> trunk_net = FeedForward(input_dimensons=1, output_dimensions=10)
@@ -154,20 +167,26 @@ class DeepONet(torch.nn.Module):
def _get_vars(self, x, indeces): def _get_vars(self, x, indeces):
if isinstance(indeces[0], str): if isinstance(indeces[0], str):
check_consistency(x, LabelTensor) try:
return x.extract(indeces) return x.extract(indeces)
except AttributeError:
raise RuntimeError('Not possible to extract input variables from tensor.'
' Ensure that the passed tensor is a LabelTensor or'
' pass list of integers to extract variables. For'
' more information refer to warning in the documentation.')
elif isinstance(indeces[0], int): elif isinstance(indeces[0], int):
return x[..., indeces] return x[..., indeces]
else: else:
raise RuntimeError('Not able to extract right indeces for tensor.') raise RuntimeError('Not able to extract right indeces for tensor.'
' For more information refer to warning in the documentation.')
def forward(self, x): def forward(self, x):
""" """
Defines the computation performed at every call. Defines the computation performed at every call.
:param LabelTensor x: the input tensor. :param LabelTensor | torch.Tensor x: The input tensor for the forward call.
:return: the output computed by the model. :return: The output computed by the DeepONet model.
:rtype: LabelTensor :rtype: LabelTensor | torch.Tensor
""" """
# forward pass # forward pass

View File

@@ -53,3 +53,15 @@ def test_forward_extract_int():
reduction='+', reduction='+',
aggregator='*') aggregator='*')
model(data) model(data)
def test_forward_extract_str_wrong():
branch_net = FeedForward(input_dimensons=1, output_dimensions=10)
trunk_net = FeedForward(input_dimensons=2, output_dimensions=10)
model = DeepONet(branch_net=branch_net,
trunk_net=trunk_net,
input_indeces_branch_net=['a'],
input_indeces_trunk_net=['b', 'c'],
reduction='+',
aggregator='*')
with pytest.raises(RuntimeError):
model(data)