"""Module for DeepONet model""" import logging from functools import partial, reduce import torch import torch.nn as nn from pina import LabelTensor from pina.model import FeedForward from pina.utils import is_function def check_combos(combos, variables): """ Check that the given combinations are subsets (overlapping is allowed) of the given set of variables. :param iterable(iterable(str)) combos: Combinations of variables. :param iterable(str) variables: Variables. """ for combo in combos: for variable in combo: if variable not in variables: raise ValueError( f"Combinations should be (overlapping) subsets of input variables, {variable} is not an input variable" ) def spawn_combo_networks( combos, layers, output_dimension, func, extra_feature, bias=True ): """ Spawn internal networks for DeepONet based on the given combos. :param iterable(iterable(str)) combos: Combinations of variables. :param iterable(int) layers: Size of hidden layers. :param int output_dimension: Size of the output layer of the networks. :param func: Nonlinearity. :param extra_feature: Extra feature to be considered by the networks. :param bool bias: Whether to consider bias or not. """ if is_function(extra_feature): extra_feature_func = lambda _: extra_feature else: extra_feature_func = extra_feature return [ FeedForward( layers=layers, input_variables=tuple(combo), output_variables=output_dimension, func=func, extra_features=extra_feature_func(combo), bias=bias, ) for combo in combos ] class DeepONet(torch.nn.Module): """ The PINA implementation of DeepONet network. .. seealso:: **Original reference**: Lu, L., Jin, P., Pang, G. et al. *Learning nonlinear operators via DeepONet based on the universal approximation theorem of operators*. Nat Mach Intell 3, 218–229 (2021). DOI: `10.1038/s42256-021-00302-5 `_ """ def __init__(self, nets, output_variables, aggregator="*", reduction="+"): """ :param iterable(torch.nn.Module) nets: Internal DeepONet networks (branch and trunk in the original DeepONet). :param list(str) output_variables: the list containing the labels corresponding to the components of the output computed by the model. :param str | callable aggregator: Aggregator to be used to aggregate partial results from the modules in `nets`. Partial results are aggregated component-wise. See :func:`pina.model.deeponet.DeepONet._symbol_functions` for the available default aggregators. :param str | callable reduction: Reduction to be used to reduce the aggregated result of the modules in `nets` to the desired output dimension. See :py:obj:`pina.model.deeponet.DeepONet._symbol_functions` for the available default reductions. 
        :Example:
            >>> branch = FFN(input_variables=['a', 'c'], output_variables=20)
            >>> trunk = FFN(input_variables=['b'], output_variables=20)
            >>> onet = DeepONet(nets=[trunk, branch],
            ...                 output_variables=output_vars)
            >>> onet
            DeepONet(
              (trunk_net): FeedForward(
                (extra_features): Sequential()
                (model): Sequential(
                  (0): Linear(in_features=1, out_features=20, bias=True)
                  (1): Tanh()
                  (2): Linear(in_features=20, out_features=20, bias=True)
                  (3): Tanh()
                  (4): Linear(in_features=20, out_features=20, bias=True)
                )
              )
              (branch_net): FeedForward(
                (extra_features): Sequential()
                (model): Sequential(
                  (0): Linear(in_features=2, out_features=20, bias=True)
                  (1): Tanh()
                  (2): Linear(in_features=20, out_features=20, bias=True)
                  (3): Tanh()
                  (4): Linear(in_features=20, out_features=20, bias=True)
                )
              )
            )
        """
        super().__init__()

        self.output_variables = output_variables
        self.output_dimension = len(output_variables)

        self._init_aggregator(aggregator, n_nets=len(nets))
        hidden_size = nets[0].model[-1].out_features
        self._init_reduction(reduction, hidden_size=hidden_size)

        if not DeepONet._all_nets_same_output_layer_size(nets):
            raise ValueError("All networks should have the same output size")
        self._nets = torch.nn.ModuleList(nets)
        logging.info("Combo DeepONet children: %s", list(self.children()))

        # Learnable affine rescaling applied to the final output.
        self.scale = torch.nn.Parameter(torch.tensor([1.0]))
        self.trasl = torch.nn.Parameter(torch.tensor([0.0]))

    @staticmethod
    def _symbol_functions(**kwargs):
        """
        Return a dictionary of functions that can be used as aggregators
        or reductions.
        """
        return {
            "+": partial(torch.sum, **kwargs),
            "*": partial(torch.prod, **kwargs),
            "mean": partial(torch.mean, **kwargs),
            "min": lambda x: torch.min(x, **kwargs).values,
            "max": lambda x: torch.max(x, **kwargs).values,
        }

    def _init_aggregator(self, aggregator, n_nets):
        aggregator_funcs = DeepONet._symbol_functions(dim=2)
        if aggregator in aggregator_funcs:
            aggregator_func = aggregator_funcs[aggregator]
        elif isinstance(aggregator, nn.Module) or is_function(aggregator):
            aggregator_func = aggregator
        elif aggregator == "linear":
            aggregator_func = nn.Linear(n_nets, len(self.output_variables))
        else:
            raise ValueError(f"Unsupported aggregation: {aggregator}")

        self._aggregator = aggregator_func
        logging.info("Selected aggregator: %s", str(aggregator_func))

        # Sanity check: the aggregator must preserve the first two
        # dimensions (batch size, output size) of its input.
        test = self._aggregator(torch.ones((20, 3, n_nets)))
        if test.ndim < 2 or tuple(test.shape)[:2] != (20, 3):
            raise ValueError(
                f"Invalid aggregator output shape: {(20, 3, n_nets)} -> {test.shape}"
            )

    def _init_reduction(self, reduction, hidden_size):
        reduction_funcs = DeepONet._symbol_functions(dim=2)
        if reduction in reduction_funcs:
            reduction_func = reduction_funcs[reduction]
        elif isinstance(reduction, nn.Module) or is_function(reduction):
            reduction_func = reduction
        elif reduction == "linear":
            reduction_func = nn.Linear(hidden_size, len(self.output_variables))
        else:
            raise ValueError(f"Unsupported reduction: {reduction}")

        self._reduction = reduction_func
        logging.info("Selected reduction: %s", str(reduction_func))

        # Sanity check: the reduction must preserve the first two
        # dimensions (batch size, output size) of its input.
        test = self._reduction(torch.ones((20, 3, hidden_size)))
        if test.ndim < 2 or tuple(test.shape)[:2] != (20, 3):
            msg = f"Invalid reduction output shape: {(20, 3, hidden_size)} -> {test.shape}"
            raise ValueError(msg)

    @staticmethod
    def _all_nets_same_output_layer_size(nets):
        size = nets[0].layers[-1].out_features
        return all(net.layers[-1].out_features == size for net in nets[1:])

    @property
    def input_variables(self):
        """The input variables of the model."""
        nets_input_variables = map(lambda net: net.input_variables, self._nets)
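        # Concatenate the per-net label lists in order. Overlapping combos
        # are allowed, so a shared variable may appear more than once here.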
        return reduce(lambda x, y: x + y, nets_input_variables)

    def forward(self, x):
        """
        Defines the computation performed at every call.

        :param LabelTensor x: The input tensor.
        :return: The output computed by the model.
        :rtype: LabelTensor
        """
        nets_outputs = tuple(
            net(x.extract(net.input_variables)) for net in self._nets
        )
        # torch.dstack(nets_outputs): (batch_size, net_output_size, n_nets)
        aggregated = self._aggregator(torch.dstack(nets_outputs))
        # net_output_size = output_variables * hidden_size
        aggregated_reshaped = aggregated.view(
            (len(x), len(self.output_variables), -1)
        )
        output_ = self._reduction(aggregated_reshaped)
        output_ = torch.squeeze(torch.atleast_3d(output_), dim=2)
        assert output_.shape == (len(x), len(self.output_variables))

        output_ = output_.as_subclass(LabelTensor)
        output_.labels = self.output_variables

        # Learnable affine rescaling of the output.
        output_ *= self.scale
        output_ += self.trasl
        return output_
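

if __name__ == "__main__":
    # Minimal usage sketch, assuming the PINA ``FeedForward`` and
    # ``LabelTensor`` signatures used above; the variable names, layer
    # sizes and labels below are illustrative, not part of the module API.
    branch_net = FeedForward(
        layers=[20, 20], input_variables=["a", "c"], output_variables=20
    )
    trunk_net = FeedForward(
        layers=[20, 20], input_variables=["b"], output_variables=20
    )
    model = DeepONet(nets=[trunk_net, branch_net], output_variables=["u"])

    # A batch of 16 labelled samples, one column per input variable.
    points = LabelTensor(torch.rand(16, 3), ["a", "b", "c"])
    out = model(points)  # expected shape: (16, 1), labelled with ['u']
    print(out.shape, out.labels)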