From b029f18c49c509ba38b349a482c0eee3e411f588 Mon Sep 17 00:00:00 2001 From: Dario Coscia Date: Wed, 6 Sep 2023 12:40:21 +0200 Subject: [PATCH] DeepONet implementation, LabelTensor modification * Implementing standard DeepONet (trunk/branch net) * Implementing multiple reduction/average techniques * Small change to LabelTensor __getitem__ for handling list --- pina/label_tensor.py | 7 +- pina/model/deeponet.py | 227 ++++++++++++------------------ tests/test_model/test_deeponet.py | 60 +++++--- 3 files changed, 140 insertions(+), 154 deletions(-) diff --git a/pina/label_tensor.py b/pina/label_tensor.py index c9e03bf..c469995 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -212,8 +212,11 @@ class LabelTensor(torch.Tensor): if selected_lt.ndim == 1: selected_lt = selected_lt.reshape(-1, 1) if hasattr(self, 'labels'): - selected_lt.labels = self.labels[index[1]] - + if isinstance(index[1], list): + selected_lt.labels = [self.labels[i] for i in index[1]] + else: + selected_lt.labels = self.labels[index[1]] + return selected_lt def __len__(self) -> int: diff --git a/pina/model/deeponet.py b/pina/model/deeponet.py index 3bedd82..b1bf489 100644 --- a/pina/model/deeponet.py +++ b/pina/model/deeponet.py @@ -1,60 +1,9 @@ """Module for DeepONet model""" -import logging -from functools import partial, reduce - import torch import torch.nn as nn - +from ..utils import check_consistency, is_function +from functools import partial, reduce from pina import LabelTensor -from pina.model import FeedForward -from pina.utils import is_function - - -def check_combos(combos, variables): - """ - Check that the given combinations are subsets (overlapping - is allowed) of the given set of variables. - - :param iterable(iterable(str)) combos: Combinations of variables. - :param iterable(str) variables: Variables. 
- """ - for combo in combos: - for variable in combo: - if variable not in variables: - raise ValueError( - f"Combinations should be (overlapping) subsets of input variables, {variable} is not an input variable" - ) - - -def spawn_combo_networks( - combos, layers, output_dimension, func, extra_feature, bias=True -): - """ - Spawn internal networks for DeepONet based on the given combos. - - :param iterable(iterable(str)) combos: Combinations of variables. - :param iterable(int) layers: Size of hidden layers. - :param int output_dimension: Size of the output layer of the networks. - :param func: Nonlinearity. - :param extra_feature: Extra feature to be considered by the networks. - :param bool bias: Whether to consider bias or not. - """ - if is_function(extra_feature): - extra_feature_func = lambda _: extra_feature - else: - extra_feature_func = extra_feature - - return [ - FeedForward( - layers=layers, - input_variables=tuple(combo), - output_variables=output_dimension, - func=func, - extra_features=extra_feature_func(combo), - bias=bias, - ) - for combo in combos - ] class DeepONet(torch.nn.Module): @@ -70,14 +19,32 @@ class DeepONet(torch.nn.Module): `_ """ - - def __init__(self, nets, output_variables, aggregator="*", reduction="+"): + def __init__(self, + branch_net, + trunk_net, + input_indeces_branch_net, + input_indeces_trunk_net, + aggregator="*", + reduction="+"): """ - :param iterable(torch.nn.Module) nets: Internal DeepONet networks - (branch and trunk in the original DeepONet). - :param list(str) output_variables: the list containing the labels - corresponding to the components of the output computed by the - model. + :param torch.nn.Module branch_net: The neural network to use as branch + model. It has to take as input a :class:`LabelTensor` + or :class:`torch.Tensor`. The number of dimensions of the output has + to be the same as the ``trunk_net``. + :param torch.nn.Module trunk_net: The neural network to use as trunk + model. 
It has to take as input a :class:`LabelTensor` + or :class:`torch.Tensor`. The number of dimensions of the output + has to be the same as the ``branch_net``. + :param list(int) | list(str) input_indeces_branch_net: List of indices + to extract from the input variable in the forward pass for the + branch net. If a list of ``int`` is passed, the corresponding columns + of the innermost entries are extracted. If a list of ``str`` is passed, + the variables of the corresponding :class:`LabelTensor` are extracted. + :param list(int) | list(str) input_indeces_trunk_net: List of indices + to extract from the input variable in the forward pass for the + trunk net. If a list of ``int`` is passed, the corresponding columns + of the innermost entries are extracted. If a list of ``str`` is passed, + the variables of the corresponding :class:`LabelTensor` are extracted. :param str | callable aggregator: Aggregator to be used to aggregate partial results from the modules in `nets`. Partial results are aggregated component-wise. See @@ -87,49 +54,66 @@ class DeepONet(torch.nn.Module): the aggregated result of the modules in `nets` to the desired output dimension. See :py:obj:`pina.model.deeponet.DeepONet._symbol_functions` for the available default reductions. - + :Example: - >>> branch = FFN(input_variables=['a', 'c'], output_variables=20) - >>> trunk = FFN(input_variables=['b'], output_variables=20) - >>> onet = DeepONet(nets=[trunk, branch], output_variables=output_vars) + >>> branch_net = FeedForward(input_dimensons=1, output_dimensions=10) + >>> trunk_net = FeedForward(input_dimensons=1, output_dimensions=10) + >>> model = DeepONet(branch_net=branch_net, + ... trunk_net=trunk_net, + ... input_indeces_branch_net=['x'], + ... input_indeces_trunk_net=['t'], + ... reduction='+', + ... 
aggregator='*') + >>> model DeepONet( - (trunk_net): FeedForward( - (extra_features): Sequential() + (trunk_net): FeedForward( (model): Sequential( (0): Linear(in_features=1, out_features=20, bias=True) (1): Tanh() (2): Linear(in_features=20, out_features=20, bias=True) (3): Tanh() - (4): Linear(in_features=20, out_features=20, bias=True) + (4): Linear(in_features=20, out_features=10, bias=True) ) - ) - (branch_net): FeedForward( - (extra_features): Sequential() + ) + (branch_net): FeedForward( (model): Sequential( - (0): Linear(in_features=2, out_features=20, bias=True) + (0): Linear(in_features=1, out_features=20, bias=True) (1): Tanh() (2): Linear(in_features=20, out_features=20, bias=True) (3): Tanh() - (4): Linear(in_features=20, out_features=20, bias=True) + (4): Linear(in_features=20, out_features=10, bias=True) ) - ) + ) ) """ super().__init__() - self.output_variables = output_variables - self.output_dimension = len(output_variables) + # check type consistency + check_consistency(input_indeces_branch_net, (str, int)) + check_consistency(input_indeces_trunk_net, (str, int)) + check_consistency(branch_net, torch.nn.Module) + check_consistency(trunk_net, torch.nn.Module) - self._init_aggregator(aggregator, n_nets=len(nets)) - hidden_size = nets[0].model[-1].out_features - self._init_reduction(reduction, hidden_size=hidden_size) + # check trunk branch nets consistency + trunk_out_dim = trunk_net.layers[-1].out_features + branch_out_dim = branch_net.layers[-1].out_features + if trunk_out_dim != branch_out_dim: + raise ValueError('Branch and trunk networks have not the same ' + 'output dimension.') - if not DeepONet._all_nets_same_output_layer_size(nets): - raise ValueError("All networks should have the same output size") - self._nets = torch.nn.ModuleList(nets) - logging.info("Combo DeepONet children: %s", list(self.children())) - self.scale = torch.nn.Parameter(torch.tensor([1.0])) - self.trasl = torch.nn.Parameter(torch.tensor([0.0])) + # assign trunk and 
branch net with their input indices + self.trunk_net = trunk_net + self._trunk_indeces = input_indeces_trunk_net + self.branch_net = branch_net + self._branch_indeces = input_indeces_branch_net + + # initialize aggregation + self._init_aggregator(aggregator=aggregator) + self._init_reduction(reduction=reduction) + + # scale and translation + self._scale = torch.nn.Parameter(torch.tensor([1.0])) + self._trasl = torch.nn.Parameter(torch.tensor([0.0])) @staticmethod def _symbol_functions(**kwargs): @@ -144,59 +128,39 @@ class DeepONet(torch.nn.Module): "min": lambda x: torch.min(x, **kwargs).values, "max": lambda x: torch.max(x, **kwargs).values, } - - def _init_aggregator(self, aggregator, n_nets): + + def _init_aggregator(self, aggregator): aggregator_funcs = DeepONet._symbol_functions(dim=2) if aggregator in aggregator_funcs: aggregator_func = aggregator_funcs[aggregator] elif isinstance(aggregator, nn.Module) or is_function(aggregator): aggregator_func = aggregator - elif aggregator == "linear": - aggregator_func = nn.Linear(n_nets, len(self.output_variables)) else: raise ValueError(f"Unsupported aggregation: {str(aggregator)}") self._aggregator = aggregator_func - logging.info("Selected aggregator: %s", str(aggregator_func)) - # test the aggregator - test = self._aggregator(torch.ones((20, 3, n_nets))) - if test.ndim < 2 or tuple(test.shape)[:2] != (20, 3): - raise ValueError( - f"Invalid aggregator output shape: {(20, 3, n_nets)} -> {test.shape}" - ) - def _init_reduction(self, reduction, hidden_size): - reduction_funcs = DeepONet._symbol_functions(dim=2) + def _init_reduction(self, reduction): + reduction_funcs = DeepONet._symbol_functions(dim=-1) if reduction in reduction_funcs: reduction_func = reduction_funcs[reduction] elif isinstance(reduction, nn.Module) or is_function(reduction): reduction_func = reduction - elif reduction == "linear": - reduction_func = nn.Linear(hidden_size, len(self.output_variables)) else: raise ValueError(f"Unsupported reduction: 
{reduction}") self._reduction = reduction_func - logging.info("Selected reduction: %s", str(reduction)) - - # test the reduction - test = self._reduction(torch.ones((20, 3, hidden_size))) - if test.ndim < 2 or tuple(test.shape)[:2] != (20, 3): - msg = f"Invalid reduction output shape: {(20, 3, hidden_size)} -> {test.shape}" - raise ValueError(msg) - - @staticmethod - def _all_nets_same_output_layer_size(nets): - size = nets[0].layers[-1].out_features - return all((net.layers[-1].out_features == size for net in nets[1:])) - - @property - def input_variables(self): - """The input variables of the model""" - nets_input_variables = map(lambda net: net.input_variables, self._nets) - return reduce(sum, nets_input_variables) + def _get_vars(self, x, indeces): + if isinstance(indeces[0], str): + check_consistency(x, LabelTensor) + return x.extract(indeces) + elif isinstance(indeces[0], int): + return x[..., indeces] + else: + raise RuntimeError('Not able to extract right indeces for tensor.') + def forward(self, x): """ Defines the computation performed at every call. 
@@ -206,23 +170,18 @@ class DeepONet(torch.nn.Module): :rtype: LabelTensor """ - nets_outputs = tuple( - net(x.extract(net.input_variables)) for net in self._nets - ) - # torch.dstack(nets_outputs): (batch_size, net_output_size, n_nets) - aggregated = self._aggregator(torch.dstack(nets_outputs)) - # net_output_size = output_variables * hidden_size - aggregated_reshaped = aggregated.view( - (len(x), len(self.output_variables), -1) - ) - output_ = self._reduction(aggregated_reshaped) - output_ = torch.squeeze(torch.atleast_3d(output_), dim=2) + # forward pass + branch_output = self.branch_net(self._get_vars(x, self._branch_indeces)) + trunk_output = self.trunk_net(self._get_vars(x, self._trunk_indeces)) - assert output_.shape == (len(x), len(self.output_variables)) + # aggregation + aggregated = self._aggregator(torch.dstack((branch_output, trunk_output))) - output_ = output_.as_subclass(LabelTensor) - output_.labels = self.output_variables + # reduce + output_ = self._reduction(aggregated).reshape(-1, 1) - output_ *= self.scale - output_ += self.trasl - return output_ + # scale and translate + output_ *= self._scale + output_ += self._trasl + + return output_ \ No newline at end of file diff --git a/tests/test_model/test_deeponet.py b/tests/test_model/test_deeponet.py index 1c7024d..b4ed899 100644 --- a/tests/test_model/test_deeponet.py +++ b/tests/test_model/test_deeponet.py @@ -3,29 +3,53 @@ import torch from pina import LabelTensor from pina.model import DeepONet -from pina.model import FeedForward as FFN +from pina.model import FeedForward data = torch.rand((20, 3)) input_vars = ['a', 'b', 'c'] -output_vars = ['d'] input_ = LabelTensor(data, input_vars) -# TODO -# def test_constructor(): -# branch = FFN(input_variables=['a', 'c'], output_variables=20) -# trunk = FFN(input_variables=['b'], output_variables=20) -# onet = DeepONet(nets=[trunk, branch], output_variables=output_vars) +def test_constructor(): + branch_net = FeedForward(input_dimensons=1, 
output_dimensions=10) + trunk_net = FeedForward(input_dimensons=2, output_dimensions=10) + DeepONet(branch_net=branch_net, + trunk_net=trunk_net, + input_indeces_branch_net=['a'], + input_indeces_trunk_net=['b', 'c'], + reduction='+', + aggregator='*') -# def test_constructor_fails_when_invalid_inner_layer_size(): -# branch = FFN(input_variables=['a', 'c'], output_variables=20) -# trunk = FFN(input_variables=['b'], output_variables=19) -# with pytest.raises(ValueError): -# DeepONet(nets=[trunk, branch], output_variables=output_vars) -# def test_forward(): -# branch = FFN(input_variables=['a', 'c'], output_variables=10) -# trunk = FFN(input_variables=['b'], output_variables=10) -# onet = DeepONet(nets=[trunk, branch], output_variables=output_vars) -# output_ = onet(input_) -# assert output_.labels == output_vars +def test_constructor_fails_when_invalid_inner_layer_size(): + branch_net = FeedForward(input_dimensons=1, output_dimensions=10) + trunk_net = FeedForward(input_dimensons=2, output_dimensions=8) + with pytest.raises(ValueError): + DeepONet(branch_net=branch_net, + trunk_net=trunk_net, + input_indeces_branch_net=['a'], + input_indeces_trunk_net=['b', 'c'], + reduction='+', + aggregator='*') + +def test_forward_extract_str(): + branch_net = FeedForward(input_dimensons=1, output_dimensions=10) + trunk_net = FeedForward(input_dimensons=2, output_dimensions=10) + model = DeepONet(branch_net=branch_net, + trunk_net=trunk_net, + input_indeces_branch_net=['a'], + input_indeces_trunk_net=['b', 'c'], + reduction='+', + aggregator='*') + model(input_) + +def test_forward_extract_int(): + branch_net = FeedForward(input_dimensons=1, output_dimensions=10) + trunk_net = FeedForward(input_dimensons=2, output_dimensions=10) + model = DeepONet(branch_net=branch_net, + trunk_net=trunk_net, + input_indeces_branch_net=[0], + input_indeces_trunk_net=[1, 2], + reduction='+', + aggregator='*') + model(data)