minor fixes deeponet, remove LabelTensor import
This commit is contained in:
committed by
Nicola Demo
parent
3a7fc63760
commit
190ee0561d
@@ -2,14 +2,18 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from ..utils import check_consistency, is_function
|
from ..utils import check_consistency, is_function
|
||||||
from functools import partial, reduce
|
from functools import partial
|
||||||
from pina import LabelTensor
|
|
||||||
|
|
||||||
|
|
||||||
class DeepONet(torch.nn.Module):
|
class DeepONet(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
The PINA implementation of DeepONet network.
|
The PINA implementation of DeepONet network.
|
||||||
|
|
||||||
|
DeepONet is a general architecture for learning Operators. Unlike
|
||||||
|
traditional machine learning methods DeepONet is designed to map
|
||||||
|
entire functions to other functions. It can be trained both with
|
||||||
|
Physics Informed or Supervised learning strategies.
|
||||||
|
|
||||||
.. seealso::
|
.. seealso::
|
||||||
|
|
||||||
**Original reference**: Lu, L., Jin, P., Pang, G. et al. *Learning
|
**Original reference**: Lu, L., Jin, P., Pang, G. et al. *Learning
|
||||||
@@ -55,6 +59,15 @@ class DeepONet(torch.nn.Module):
|
|||||||
dimension. See :py:obj:`pina.model.deeponet.DeepONet._symbol_functions` for the available default
|
dimension. See :py:obj:`pina.model.deeponet.DeepONet._symbol_functions` for the available default
|
||||||
reductions.
|
reductions.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
In the forward pass we do not check if the input is instance of
|
||||||
|
:class:`LabelTensor` or :class:`torch.Tensor`. A general rule is
|
||||||
|
that for a :class:`LabelTensor` input both list of integers and
|
||||||
|
list of strings can be passed for ``input_indeces_branch_net``
|
||||||
|
and ``input_indeces_trunk_net``. Differently, for a :class:`torch.Tensor`
|
||||||
|
only a list of integers can be passed for ``input_indeces_branch_net``
|
||||||
|
and ``input_indeces_trunk_net``.
|
||||||
|
|
||||||
:Example:
|
:Example:
|
||||||
>>> branch_net = FeedForward(input_dimensons=1, output_dimensions=10)
|
>>> branch_net = FeedForward(input_dimensons=1, output_dimensions=10)
|
||||||
>>> trunk_net = FeedForward(input_dimensons=1, output_dimensions=10)
|
>>> trunk_net = FeedForward(input_dimensons=1, output_dimensions=10)
|
||||||
@@ -154,20 +167,26 @@ class DeepONet(torch.nn.Module):
|
|||||||
|
|
||||||
def _get_vars(self, x, indeces):
|
def _get_vars(self, x, indeces):
|
||||||
if isinstance(indeces[0], str):
|
if isinstance(indeces[0], str):
|
||||||
check_consistency(x, LabelTensor)
|
try:
|
||||||
return x.extract(indeces)
|
return x.extract(indeces)
|
||||||
|
except AttributeError:
|
||||||
|
raise RuntimeError('Not possible to extract input variables from tensor.'
|
||||||
|
' Ensure that the passed tensor is a LabelTensor or'
|
||||||
|
' pass list of integers to extract variables. For'
|
||||||
|
' more information refer to warning in the documentation.')
|
||||||
elif isinstance(indeces[0], int):
|
elif isinstance(indeces[0], int):
|
||||||
return x[..., indeces]
|
return x[..., indeces]
|
||||||
else:
|
else:
|
||||||
raise RuntimeError('Not able to extract right indeces for tensor.')
|
raise RuntimeError('Not able to extract right indeces for tensor.'
|
||||||
|
' For more information refer to warning in the documentation.')
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
"""
|
"""
|
||||||
Defines the computation performed at every call.
|
Defines the computation performed at every call.
|
||||||
|
|
||||||
:param LabelTensor x: the input tensor.
|
:param LabelTensor | torch.Tensor x: The input tensor for the forward call.
|
||||||
:return: the output computed by the model.
|
:return: The output computed by the DeepONet model.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor | torch.Tensor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
|
|||||||
@@ -53,3 +53,15 @@ def test_forward_extract_int():
|
|||||||
reduction='+',
|
reduction='+',
|
||||||
aggregator='*')
|
aggregator='*')
|
||||||
model(data)
|
model(data)
|
||||||
|
|
||||||
|
def test_forward_extract_str_wrong():
|
||||||
|
branch_net = FeedForward(input_dimensons=1, output_dimensions=10)
|
||||||
|
trunk_net = FeedForward(input_dimensons=2, output_dimensions=10)
|
||||||
|
model = DeepONet(branch_net=branch_net,
|
||||||
|
trunk_net=trunk_net,
|
||||||
|
input_indeces_branch_net=['a'],
|
||||||
|
input_indeces_trunk_net=['b', 'c'],
|
||||||
|
reduction='+',
|
||||||
|
aggregator='*')
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
model(data)
|
||||||
|
|||||||
Reference in New Issue
Block a user