"""Module for DeepONet model""" import torch import torch.nn as nn from ..utils import check_consistency, is_function from functools import partial, reduce from pina import LabelTensor class DeepONet(torch.nn.Module): """ The PINA implementation of DeepONet network. .. seealso:: **Original reference**: Lu, L., Jin, P., Pang, G. et al. *Learning nonlinear operators via DeepONet based on the universal approximation theorem of operators*. Nat Mach Intell 3, 218–229 (2021). DOI: `10.1038/s42256-021-00302-5 `_ """ def __init__(self, branch_net, trunk_net, input_indeces_branch_net, input_indeces_trunk_net, aggregator="*", reduction="+"): """ :param torch.nn.Module branch_net: The neural network to use as branch model. It has to take as input a :class:`LabelTensor` or :class:`torch.Tensor`. The number of dimensions of the output has to be the same of the ``trunk_net``. :param torch.nn.Module trunk_net: The neural network to use as trunk model. It has to take as input a :class:`LabelTensor` or :class:`torch.Tensor`. The number of dimensions of the output has to be the same of the ``branch_net``. :param list(int) | list(str) input_indeces_branch_net: List of indeces to extract from the input variable in the forward pass for the branch net. If a list of ``int`` is passed, the corresponding columns of the inner most entries are extracted. If a list of ``str`` is passed the variables of the corresponding :class:`LabelTensor` are extracted. :param list(int) | list(str) input_indeces_trunk_net: List of indeces to extract from the input variable in the forward pass for the trunk net. If a list of ``int`` is passed, the corresponding columns of the inner most entries are extracted. If a list of ``str`` is passed the variables of the corresponding :class:`LabelTensor` are extracted. :param str | callable aggregator: Aggregator to be used to aggregate partial results from the modules in `nets`. Partial results are aggregated component-wise. See :func:`pina.model.deeponet.DeepONet._symbol_functions` for the available default aggregators. :param str | callable reduction: Reduction to be used to reduce the aggregated result of the modules in `nets` to the desired output dimension. See :py:obj:`pina.model.deeponet.DeepONet._symbol_functions` for the available default reductions. :Example: >>> branch_net = FeedForward(input_dimensons=1, output_dimensions=10) >>> trunk_net = FeedForward(input_dimensons=1, output_dimensions=10) >>> model = DeepONet(branch_net=branch_net, ... trunk_net=trunk_net, ... input_indeces_branch_net=['x'], ... input_indeces_trunk_net=['t'], ... reduction='+', ... aggregator='*') >>> model DeepONet( (trunk_net): FeedForward( (model): Sequential( (0): Linear(in_features=1, out_features=20, bias=True) (1): Tanh() (2): Linear(in_features=20, out_features=20, bias=True) (3): Tanh() (4): Linear(in_features=20, out_features=10, bias=True) ) ) (branch_net): FeedForward( (model): Sequential( (0): Linear(in_features=1, out_features=20, bias=True) (1): Tanh() (2): Linear(in_features=20, out_features=20, bias=True) (3): Tanh() (4): Linear(in_features=20, out_features=10, bias=True) ) ) ) """ super().__init__() # check type consistency check_consistency(input_indeces_branch_net, (str, int)) check_consistency(input_indeces_trunk_net, (str, int)) check_consistency(branch_net, torch.nn.Module) check_consistency(trunk_net, torch.nn.Module) # check trunk branch nets consistency trunk_out_dim = trunk_net.layers[-1].out_features branch_out_dim = branch_net.layers[-1].out_features if trunk_out_dim != branch_out_dim: raise ValueError('Branch and trunk networks have not the same ' 'output dimension.') # assign trunk and branch net with their input indeces self.trunk_net = trunk_net self._trunk_indeces = input_indeces_trunk_net self.branch_net = branch_net self._branch_indeces = input_indeces_branch_net # initializie aggregation self._init_aggregator(aggregator=aggregator) self._init_reduction(reduction=reduction) # scale and translation self._scale = torch.nn.Parameter(torch.tensor([1.0])) self._trasl = torch.nn.Parameter(torch.tensor([0.0])) @staticmethod def _symbol_functions(**kwargs): """ Return a dictionary of functions that can be used as aggregators or reductions. """ return { "+": partial(torch.sum, **kwargs), "*": partial(torch.prod, **kwargs), "mean": partial(torch.mean, **kwargs), "min": lambda x: torch.min(x, **kwargs).values, "max": lambda x: torch.max(x, **kwargs).values, } def _init_aggregator(self, aggregator): aggregator_funcs = DeepONet._symbol_functions(dim=2) if aggregator in aggregator_funcs: aggregator_func = aggregator_funcs[aggregator] elif isinstance(aggregator, nn.Module) or is_function(aggregator): aggregator_func = aggregator else: raise ValueError(f"Unsupported aggregation: {str(aggregator)}") self._aggregator = aggregator_func def _init_reduction(self, reduction): reduction_funcs = DeepONet._symbol_functions(dim=-1) if reduction in reduction_funcs: reduction_func = reduction_funcs[reduction] elif isinstance(reduction, nn.Module) or is_function(reduction): reduction_func = reduction else: raise ValueError(f"Unsupported reduction: {reduction}") self._reduction = reduction_func def _get_vars(self, x, indeces): if isinstance(indeces[0], str): check_consistency(x, LabelTensor) return x.extract(indeces) elif isinstance(indeces[0], int): return x[..., indeces] else: raise RuntimeError('Not able to extract right indeces for tensor.') def forward(self, x): """ Defines the computation performed at every call. :param LabelTensor x: the input tensor. :return: the output computed by the model. :rtype: LabelTensor """ # forward pass branch_output = self.branch_net(self._get_vars(x, self._branch_indeces)) trunk_output = self.trunk_net(self._get_vars(x, self._trunk_indeces)) # aggregation aggregated = self._aggregator(torch.dstack((branch_output, trunk_output))) # reduce output_ = self._reduction(aggregated).reshape(-1, 1) # scale and translate output_ *= self._scale output_ += self._trasl return output_