Files
PINA/pina/model/deeponet.py
Dario Coscia b029f18c49 DeepOnet implementation, LabelTensor modification
* Implementing standard DeepOnet (trunk/branch net)
* Implementing multiple reduction/ average techniques
* Small change  LabelTensor __getitem__ for handling list
2023-11-17 09:51:29 +01:00

187 lines
7.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Module for DeepONet model"""
import torch
import torch.nn as nn
from ..utils import check_consistency, is_function
from functools import partial, reduce
from pina import LabelTensor
class DeepONet(torch.nn.Module):
    """
    The PINA implementation of DeepONet network.

    The model evaluates a branch network and a trunk network on (possibly
    different) components of the same input, combines the two outputs with
    an *aggregator*, collapses the result along the last dimension with a
    *reduction* and finally applies a learnable scale and translation.

    .. seealso::

        **Original reference**: Lu, L., Jin, P., Pang, G. et al. *Learning
        nonlinear operators via DeepONet based on the universal approximation
        theorem of operators*. Nat Mach Intell 3, 218-229 (2021).
        DOI: `10.1038/s42256-021-00302-5
        <https://doi.org/10.1038/s42256-021-00302-5>`_
    """

    def __init__(self,
                 branch_net,
                 trunk_net,
                 input_indeces_branch_net,
                 input_indeces_trunk_net,
                 aggregator="*",
                 reduction="+"):
        """
        :param torch.nn.Module branch_net: The neural network to use as branch
            model. It has to take as input a :class:`LabelTensor`
            or :class:`torch.Tensor`. The number of dimensions of the output
            has to be the same of the ``trunk_net``.
        :param torch.nn.Module trunk_net: The neural network to use as trunk
            model. It has to take as input a :class:`LabelTensor`
            or :class:`torch.Tensor`. The number of dimensions of the output
            has to be the same of the ``branch_net``.
        :param list(int) | list(str) input_indeces_branch_net: List of indices
            to extract from the input variable in the forward pass for the
            branch net. If a list of ``int`` is passed, the corresponding
            columns of the inner most entries are extracted. If a list of
            ``str`` is passed the variables of the corresponding
            :class:`LabelTensor` are extracted.
        :param list(int) | list(str) input_indeces_trunk_net: List of indices
            to extract from the input variable in the forward pass for the
            trunk net. If a list of ``int`` is passed, the corresponding
            columns of the inner most entries are extracted. If a list of
            ``str`` is passed the variables of the corresponding
            :class:`LabelTensor` are extracted.
        :param str | callable aggregator: Aggregator used to combine the
            outputs of the branch and trunk networks. It receives the two
            outputs stacked with :func:`torch.dstack`. See
            :func:`pina.model.deeponet.DeepONet._symbol_functions` for the
            available default aggregators.
        :param str | callable reduction: Reduction used to reduce the
            aggregated result to the desired output dimension. See
            :py:obj:`pina.model.deeponet.DeepONet._symbol_functions` for the
            available default reductions.
        :raises ValueError: If the output dimensions of the two networks can
            be read and differ, or if ``aggregator``/``reduction`` is not
            supported.

        :Example:
            >>> branch_net = FeedForward(input_dimensions=1,
            ...                          output_dimensions=10)
            >>> trunk_net = FeedForward(input_dimensions=1,
            ...                         output_dimensions=10)
            >>> model = DeepONet(branch_net=branch_net,
            ...                  trunk_net=trunk_net,
            ...                  input_indeces_branch_net=['x'],
            ...                  input_indeces_trunk_net=['t'],
            ...                  reduction='+',
            ...                  aggregator='*')
            >>> model
            DeepONet(
              (trunk_net): FeedForward(
                (model): Sequential(
                  (0): Linear(in_features=1, out_features=20, bias=True)
                  (1): Tanh()
                  (2): Linear(in_features=20, out_features=20, bias=True)
                  (3): Tanh()
                  (4): Linear(in_features=20, out_features=10, bias=True)
                )
              )
              (branch_net): FeedForward(
                (model): Sequential(
                  (0): Linear(in_features=1, out_features=20, bias=True)
                  (1): Tanh()
                  (2): Linear(in_features=20, out_features=20, bias=True)
                  (3): Tanh()
                  (4): Linear(in_features=20, out_features=10, bias=True)
                )
              )
            )
        """
        super().__init__()

        # check type consistency
        check_consistency(input_indeces_branch_net, (str, int))
        check_consistency(input_indeces_trunk_net, (str, int))
        check_consistency(branch_net, torch.nn.Module)
        check_consistency(trunk_net, torch.nn.Module)

        # check trunk branch nets consistency; the check is only possible
        # when both networks expose ``layers[-1].out_features``, otherwise a
        # mismatch surfaces when the outputs are stacked in ``forward``
        trunk_out_dim = DeepONet._output_dimension(trunk_net)
        branch_out_dim = DeepONet._output_dimension(branch_net)
        if (trunk_out_dim is not None and branch_out_dim is not None
                and trunk_out_dim != branch_out_dim):
            raise ValueError('Branch and trunk networks have not the same '
                             'output dimension.')

        # assign trunk and branch net with their input indices
        self.trunk_net = trunk_net
        self._trunk_indeces = input_indeces_trunk_net
        self.branch_net = branch_net
        self._branch_indeces = input_indeces_branch_net

        # initialize aggregation and reduction
        self._init_aggregator(aggregator=aggregator)
        self._init_reduction(reduction=reduction)

        # learnable scale and translation applied to the reduced output
        self._scale = torch.nn.Parameter(torch.tensor([1.0]))
        self._trasl = torch.nn.Parameter(torch.tensor([0.0]))

    @staticmethod
    def _output_dimension(net):
        """
        Return ``net.layers[-1].out_features`` when available.

        :param torch.nn.Module net: The network to inspect.
        :return: The number of output features, or ``None`` if ``net`` has no
            usable ``layers`` attribute or its last layer does not define
            ``out_features``.
        :rtype: int | None
        """
        try:
            return net.layers[-1].out_features
        except (AttributeError, IndexError, KeyError, TypeError):
            return None

    @staticmethod
    def _symbol_functions(**kwargs):
        """
        Return a dictionary of functions that can be used as aggregators or
        reductions.

        :param kwargs: Keyword arguments (e.g. ``dim``) forwarded to the
            underlying ``torch`` function on every call.
        :return: Mapping from symbol (``"+"``, ``"*"``, ``"mean"``,
            ``"min"``, ``"max"``) to a callable taking a tensor.
        :rtype: dict
        """

        def _values_only(func):
            # torch.min/max return a (values, indices) named tuple when a
            # ``dim`` is given, but a plain tensor otherwise
            def wrapped(x):
                result = func(x, **kwargs)
                if isinstance(result, torch.Tensor):
                    return result
                return result.values

            return wrapped

        return {
            "+": partial(torch.sum, **kwargs),
            "*": partial(torch.prod, **kwargs),
            "mean": partial(torch.mean, **kwargs),
            "min": _values_only(torch.min),
            "max": _values_only(torch.max),
        }

    def _init_aggregator(self, aggregator):
        """
        Set ``self._aggregator`` from a symbol or a user supplied callable.

        Default aggregators act on dimension ``2``, i.e. the axis created
        by :func:`torch.dstack` when the network outputs are 2-D.

        :param str | callable aggregator: One of the keys of
            :meth:`_symbol_functions`, a :class:`torch.nn.Module` or a
            function.
        :raises ValueError: If ``aggregator`` is none of the above.
        """
        aggregator_funcs = DeepONet._symbol_functions(dim=2)
        # isinstance guard: an unhashable argument must reach the ValueError
        # below instead of failing the dictionary lookup with a TypeError
        if isinstance(aggregator, str) and aggregator in aggregator_funcs:
            aggregator_func = aggregator_funcs[aggregator]
        elif isinstance(aggregator, nn.Module) or is_function(aggregator):
            aggregator_func = aggregator
        else:
            raise ValueError(f"Unsupported aggregation: {str(aggregator)}")
        self._aggregator = aggregator_func

    def _init_reduction(self, reduction):
        """
        Set ``self._reduction`` from a symbol or a user supplied callable.

        Default reductions act on the last dimension of their input.

        :param str | callable reduction: One of the keys of
            :meth:`_symbol_functions`, a :class:`torch.nn.Module` or a
            function.
        :raises ValueError: If ``reduction`` is none of the above.
        """
        reduction_funcs = DeepONet._symbol_functions(dim=-1)
        if isinstance(reduction, str) and reduction in reduction_funcs:
            reduction_func = reduction_funcs[reduction]
        elif isinstance(reduction, nn.Module) or is_function(reduction):
            reduction_func = reduction
        else:
            raise ValueError(f"Unsupported reduction: {reduction}")
        self._reduction = reduction_func

    def _get_vars(self, x, indeces):
        """
        Extract from ``x`` the components selected by ``indeces``.

        :param x: The input tensor; it must be a :class:`LabelTensor` when
            ``indeces`` contains strings.
        :param indeces: Labels (``str``) or positions along the last
            dimension (``int``). A single ``str``/``int`` is treated as a
            one-element list.
        :return: The extracted components of ``x``.
        :raises RuntimeError: If ``indeces`` is empty or its first element
            is neither a ``str`` nor an ``int``.
        """
        # a bare int is not subscriptable: normalise scalars to a list
        if isinstance(indeces, (str, int)):
            indeces = [indeces]

        first = indeces[0] if len(indeces) > 0 else None
        if isinstance(first, str):
            check_consistency(x, LabelTensor)
            return x.extract(indeces)
        if isinstance(first, int):
            return x[..., indeces]
        raise RuntimeError('Not able to extract right indeces for tensor.')

    def forward(self, x):
        """
        Defines the computation performed at every call.

        :param LabelTensor x: the input tensor.
        :return: the output computed by the model, reshaped to ``(-1, 1)``.
        :rtype: LabelTensor | torch.Tensor
        """
        # forward pass
        branch_output = self.branch_net(
            self._get_vars(x, self._branch_indeces))
        trunk_output = self.trunk_net(
            self._get_vars(x, self._trunk_indeces))

        # aggregation: for 2-D outputs torch.dstack adds a third axis that
        # holds the (branch, trunk) pair; 3-D outputs are concatenated
        # along their existing third axis instead
        aggregated = self._aggregator(
            torch.dstack((branch_output, trunk_output)))

        # reduce
        output_ = self._reduction(aggregated).reshape(-1, 1)

        # scale and translate out of place: ``output_`` may be a view of a
        # tensor autograd saved for the backward pass, and modifying it in
        # place would invalidate that saved value
        return output_ * self._scale + self._trasl