minor fixes deeponet, remove LabelTensor import
This commit is contained in:
committed by
Nicola Demo
parent
3a7fc63760
commit
190ee0561d
@@ -2,14 +2,18 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ..utils import check_consistency, is_function
|
||||
from functools import partial, reduce
|
||||
from pina import LabelTensor
|
||||
from functools import partial
|
||||
|
||||
|
||||
class DeepONet(torch.nn.Module):
|
||||
"""
|
||||
The PINA implementation of DeepONet network.
|
||||
|
||||
DeepONet is a general architecture for learning Operators. Unlike
|
||||
traditional machine learning methods DeepONet is designed to map
|
||||
entire functions to other functions. It can be trained both with
|
||||
Physics Informed or Supervised learning strategies.
|
||||
|
||||
.. seealso::
|
||||
|
||||
**Original reference**: Lu, L., Jin, P., Pang, G. et al. *Learning
|
||||
@@ -55,6 +59,15 @@ class DeepONet(torch.nn.Module):
|
||||
dimension. See :py:obj:`pina.model.deeponet.DeepONet._symbol_functions` for the available default
|
||||
reductions.
|
||||
|
||||
.. warning::
|
||||
In the forward pass we do not check if the input is instance of
|
||||
:class:`LabelTensor` or :class:`torch.Tensor`. A general rule is
|
||||
that for a :class:`LabelTensor` input both list of integers and
|
||||
list of strings can be passed for ``input_indeces_branch_net``
|
||||
and ``input_indeces_trunk_net``. Differently, for a :class:`torch.Tensor`
|
||||
only a list of integers can be passed for ``input_indeces_branch_net``
|
||||
and ``input_indeces_trunk_net``.
|
||||
|
||||
:Example:
|
||||
>>> branch_net = FeedForward(input_dimensons=1, output_dimensions=10)
|
||||
>>> trunk_net = FeedForward(input_dimensons=1, output_dimensions=10)
|
||||
@@ -154,20 +167,26 @@ class DeepONet(torch.nn.Module):
|
||||
|
||||
def _get_vars(self, x, indeces):
|
||||
if isinstance(indeces[0], str):
|
||||
check_consistency(x, LabelTensor)
|
||||
return x.extract(indeces)
|
||||
try:
|
||||
return x.extract(indeces)
|
||||
except AttributeError:
|
||||
raise RuntimeError('Not possible to extract input variables from tensor.'
|
||||
' Ensure that the passed tensor is a LabelTensor or'
|
||||
' pass list of integers to extract variables. For'
|
||||
' more information refer to warning in the documentation.')
|
||||
elif isinstance(indeces[0], int):
|
||||
return x[..., indeces]
|
||||
else:
|
||||
raise RuntimeError('Not able to extract right indeces for tensor.')
|
||||
raise RuntimeError('Not able to extract right indeces for tensor.'
|
||||
' For more information refer to warning in the documentation.')
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Defines the computation performed at every call.
|
||||
|
||||
:param LabelTensor x: the input tensor.
|
||||
:return: the output computed by the model.
|
||||
:rtype: LabelTensor
|
||||
:param LabelTensor | torch.Tensor x: The input tensor for the forward call.
|
||||
:return: The output computed by the DeepONet model.
|
||||
:rtype: LabelTensor | torch.Tensor
|
||||
"""
|
||||
|
||||
# forward pass
|
||||
|
||||
@@ -53,3 +53,15 @@ def test_forward_extract_int():
|
||||
reduction='+',
|
||||
aggregator='*')
|
||||
model(data)
|
||||
|
||||
def test_forward_extract_str_wrong():
|
||||
branch_net = FeedForward(input_dimensons=1, output_dimensions=10)
|
||||
trunk_net = FeedForward(input_dimensons=2, output_dimensions=10)
|
||||
model = DeepONet(branch_net=branch_net,
|
||||
trunk_net=trunk_net,
|
||||
input_indeces_branch_net=['a'],
|
||||
input_indeces_trunk_net=['b', 'c'],
|
||||
reduction='+',
|
||||
aggregator='*')
|
||||
with pytest.raises(RuntimeError):
|
||||
model(data)
|
||||
|
||||
Reference in New Issue
Block a user