minor fixes deeponet, remove LabelTensor import

This commit is contained in:
Dario Coscia
2023-09-07 10:47:26 +02:00
committed by Nicola Demo
parent 3a7fc63760
commit 190ee0561d
2 changed files with 39 additions and 8 deletions

View File

@@ -2,14 +2,18 @@
import torch
import torch.nn as nn
from ..utils import check_consistency, is_function
from functools import partial, reduce
from pina import LabelTensor
from functools import partial
class DeepONet(torch.nn.Module):
"""
The PINA implementation of DeepONet network.
DeepONet is a general architecture for learning Operators. Unlike
traditional machine learning methods DeepONet is designed to map
entire functions to other functions. It can be trained both with
Physics Informed or Supervised learning strategies.
.. seealso::
**Original reference**: Lu, L., Jin, P., Pang, G. et al. *Learning
@@ -55,6 +59,15 @@ class DeepONet(torch.nn.Module):
dimension. See :py:obj:`pina.model.deeponet.DeepONet._symbol_functions` for the available default
reductions.
.. warning::
In the forward pass we do not check if the input is instance of
:class:`LabelTensor` or :class:`torch.Tensor`. A general rule is
that for a :class:`LabelTensor` input both list of integers and
list of strings can be passed for ``input_indeces_branch_net``
and ``input_indeces_trunk_net``. Differently, for a :class:`torch.Tensor`
only a list of integers can be passed for ``input_indeces_branch_net``
and ``input_indeces_trunk_net``.
:Example:
>>> branch_net = FeedForward(input_dimensons=1, output_dimensions=10)
>>> trunk_net = FeedForward(input_dimensons=1, output_dimensions=10)
@@ -154,20 +167,26 @@ class DeepONet(torch.nn.Module):
def _get_vars(self, x, indeces):
if isinstance(indeces[0], str):
check_consistency(x, LabelTensor)
try:
return x.extract(indeces)
except AttributeError:
raise RuntimeError('Not possible to extract input variables from tensor.'
' Ensure that the passed tensor is a LabelTensor or'
' pass list of integers to extract variables. For'
' more information refer to warning in the documentation.')
elif isinstance(indeces[0], int):
return x[..., indeces]
else:
raise RuntimeError('Not able to extract right indeces for tensor.')
raise RuntimeError('Not able to extract right indeces for tensor.'
' For more information refer to warning in the documentation.')
def forward(self, x):
"""
Defines the computation performed at every call.
:param LabelTensor x: the input tensor.
:return: the output computed by the model.
:rtype: LabelTensor
:param LabelTensor | torch.Tensor x: The input tensor for the forward call.
:return: The output computed by the DeepONet model.
:rtype: LabelTensor | torch.Tensor
"""
# forward pass

View File

@@ -53,3 +53,15 @@ def test_forward_extract_int():
reduction='+',
aggregator='*')
model(data)
def test_forward_extract_str_wrong():
branch_net = FeedForward(input_dimensons=1, output_dimensions=10)
trunk_net = FeedForward(input_dimensons=2, output_dimensions=10)
model = DeepONet(branch_net=branch_net,
trunk_net=trunk_net,
input_indeces_branch_net=['a'],
input_indeces_trunk_net=['b', 'c'],
reduction='+',
aggregator='*')
with pytest.raises(RuntimeError):
model(data)