From 7ce080fd8527796cff0611933d819e3d1b8e5ff9 Mon Sep 17 00:00:00 2001 From: Francesco Andreuzzi Date: Wed, 11 Jan 2023 12:07:19 +0100 Subject: [PATCH] Generic DeepONet (#68) * generic deeponet * example for generic deeponet * adapt tests to new interface --- examples/run_poisson_deeponet.py | 113 +++++++++++++++++++++ pina/model/deeponet.py | 165 ++++++++++++++++++++++++++----- pina/model/feed_forward.py | 10 +- pina/utils.py | 15 ++- tests/test_deeponet.py | 14 ++- 5 files changed, 280 insertions(+), 37 deletions(-) create mode 100644 examples/run_poisson_deeponet.py diff --git a/examples/run_poisson_deeponet.py b/examples/run_poisson_deeponet.py new file mode 100644 index 0000000..4f5e69d --- /dev/null +++ b/examples/run_poisson_deeponet.py @@ -0,0 +1,113 @@ +import argparse +import logging + +import torch +from problems.poisson import Poisson + +from pina import PINN, LabelTensor, Plotter +from pina.model.deeponet import DeepONet, check_combos, spawn_combo_networks + +logging.basicConfig( + filename="poisson_deeponet.log", filemode="w", level=logging.INFO +) + + +class SinFeature(torch.nn.Module): + """ + Feature: sin(x) + """ + + def __init__(self, label): + super().__init__() + + if not isinstance(label, (tuple, list)): + label = [label] + self._label = label + + def forward(self, x): + """ + Defines the computation performed at every call. + + :param LabelTensor x: the input tensor. + :return: the output computed by the model. + :rtype: LabelTensor + """ + t = torch.sin(x.extract(self._label) * torch.pi) + return LabelTensor(t, [f"sin({self._label})"]) + + +def prepare_deeponet_model(args, problem, extra_feature_combo_func=None): + combos = tuple(map(lambda combo: combo.split("-"), args.combos.split(","))) + check_combos(combos, problem.input_variables) + + extra_feature = extra_feature_combo_func if args.extra else None + networks = spawn_combo_networks( + combos=combos, + layers=list(map(int, args.layers.split(","))) if args.layers else [], + output_dimension=args.hidden * len(problem.output_variables), + func=torch.nn.Softplus, + extra_feature=extra_feature, + bias=not args.nobias, + ) + + return DeepONet( + networks, + problem.output_variables, + aggregator=args.aggregator, + reduction=args.reduction, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run PINA") + parser.add_argument("-s", "--save", action="store_true") + parser.add_argument("-l", "--load", action="store_true") + parser.add_argument("id_run", help="Run ID", type=int) + + parser.add_argument("--extra", help="Extra features", action="store_true") + parser.add_argument("--nobias", action="store_true") + parser.add_argument( + "--combos", + help="DeepONet internal network combinations", + type=str, + required=True, + ) + parser.add_argument( + "--aggregator", help="Aggregator for DeepONet", type=str, default="*" + ) + parser.add_argument( + "--reduction", help="Reduction for DeepONet", type=str, default="+" + ) + parser.add_argument( + "--hidden", + help="Number of variables in the hidden DeepONet layer", + type=int, + required=True, + ) + parser.add_argument( + "--layers", + help="Structure of the DeepONet partial layers", + type=str, + required=True, + ) + cli_args = parser.parse_args() + + poisson_problem = Poisson() + + model = prepare_deeponet_model( + cli_args, + poisson_problem, + extra_feature_combo_func=lambda combo: [SinFeature(combo)], + ) + pinn = PINN(poisson_problem, model, lr=0.01, regularizer=1e-8) + if cli_args.save: + pinn.span_pts( + 20, "grid", locations=["gamma1", "gamma2", "gamma3", "gamma4"] + ) + pinn.span_pts(20, "grid", locations=["D"]) + pinn.train(1.0e-10, 100) + pinn.save_state(f"pina.poisson_{cli_args.id_run}") + if cli_args.load: + pinn.load_state(f"pina.poisson_{cli_args.id_run}") + plotter = Plotter() + plotter.plot(pinn) diff --git a/pina/model/deeponet.py b/pina/model/deeponet.py index ca78898..f956248 100644 --- a/pina/model/deeponet.py +++ b/pina/model/deeponet.py @@ -1,8 +1,60 @@ """Module for DeepONet model""" +import logging +from functools import partial, reduce + import torch import torch.nn as nn from pina import LabelTensor +from pina.model import FeedForward +from pina.utils import is_function + + +def check_combos(combos, variables): + """ + Check that the given combinations are subsets (overlapping + is allowed) of the given set of variables. + + :param iterable(iterable(str)) combos: Combinations of variables. + :param iterable(str) variables: Variables. + """ + for combo in combos: + for variable in combo: + if variable not in variables: + raise ValueError( + f"Combinations should be (overlapping) subsets of input variables, {variable} is not an input variable" + ) + + +def spawn_combo_networks( + combos, layers, output_dimension, func, extra_feature, bias=True +): + """ + Spawn internal networks for DeepONet based on the given combos. + + :param iterable(iterable(str)) combos: Combinations of variables. + :param iterable(int) layers: Size of hidden layers. + :param int output_dimension: Size of the output layer of the networks. + :param func: Nonlinearity. + :param extra_feature: Extra feature to be considered by the networks. + :param bool bias: Whether to consider bias or not. + """ + if is_function(extra_feature): + extra_feature_func = lambda _: extra_feature + else: + extra_feature_func = extra_feature + + return [ + FeedForward( + layers=layers, + input_variables=tuple(combo), + output_variables=output_dimension, + func=func, + extra_features=extra_feature_func(combo), + bias=bias, + ) + for combo in combos + ] class DeepONet(torch.nn.Module): @@ -18,23 +70,27 @@ class DeepONet(torch.nn.Module): `_ """ - def __init__(self, branch_net, trunk_net, output_variables): + + def __init__(self, nets, output_variables, aggregator="*", reduction="+"): """ - :param torch.nn.Module branch_net: the neural network to use as branch - model. It has to take as input a :class:`LabelTensor`. The number - of dimension of the output has to be the same of the `trunk_net`. - :param torch.nn.Module trunk_net: the neural network to use as trunk - model. It has to take as input a :class:`LabelTensor`. The number - of dimension of the output has to be the same of the `branch_net`. + :param iterable(torch.nn.Module) nets: Internal DeepONet networks + (branch and trunk in the original DeepONet). :param list(str) output_variables: the list containing the labels corresponding to the components of the output computed by the model. + :param string | callable aggregator: Aggregator to be used to aggregate + partial results from the modules in `nets`. Partial results are + aggregated component-wise. See :func:`_symbol_functions` for the + available default aggregators. + :param string | callable reduction: Reduction to be used to reduce + the aggregated result of the modules in `nets` to the desired output + dimension. See :func:`_symbol_functions` for the available default + reductions. :Example: >>> branch = FFN(input_variables=['a', 'c'], output_variables=20) >>> trunk = FFN(input_variables=['b'], output_variables=20) - >>> onet = DeepONet(trunk_net=trunk, branch_net=branch - >>> output_variables=output_vars) + >>> onet = DeepONet(nets=[trunk, branch], output_variables=output_vars) DeepONet( (trunk_net): FeedForward( (extra_features): Sequential() @@ -63,22 +119,76 @@ class DeepONet(torch.nn.Module): self.output_variables = output_variables self.output_dimension = len(output_variables) - trunk_out_dim = trunk_net.layers[-1].out_features - branch_out_dim = branch_net.layers[-1].out_features + self._init_aggregator(aggregator, n_nets=len(nets)) + hidden_size = nets[0].model[-1].out_features + self._init_reduction(reduction, hidden_size=hidden_size) - if trunk_out_dim != branch_out_dim: - raise ValueError('Branch and trunk networks have not the same ' - 'output dimension.') + if not DeepONet._all_nets_same_output_layer_size(nets): + raise ValueError("All networks should have the same output size") + self._nets = torch.nn.ModuleList(nets) + logging.info("Combo DeepONet children: %s", list(self.children())) - self.trunk_net = trunk_net - self.branch_net = branch_net + @staticmethod + def _symbol_functions(**kwargs): + return { + "+": partial(torch.sum, **kwargs), + "*": partial(torch.prod, **kwargs), + "mean": partial(torch.mean, **kwargs), + "min": lambda x: torch.min(x, **kwargs).values, + "max": lambda x: torch.max(x, **kwargs).values, + } - self.reduction = nn.Linear(trunk_out_dim, self.output_dimension) + def _init_aggregator(self, aggregator, n_nets): + aggregator_funcs = DeepONet._symbol_functions(dim=2) + if aggregator in aggregator_funcs: + aggregator_func = aggregator_funcs[aggregator] + elif isinstance(aggregator, nn.Module) or is_function(aggregator): + aggregator_func = aggregator + elif aggregator == "linear": + aggregator_func = nn.Linear(n_nets, len(self.output_variables)) + else: + raise ValueError(f"Unsupported aggregation: {str(aggregator)}") + + self._aggregator = aggregator_func + logging.info("Selected aggregator: %s", str(aggregator_func)) + + # test the aggregator + test = self._aggregator(torch.ones((20, 3, n_nets))) + if test.ndim < 2 or tuple(test.shape)[:2] != (20, 3): + raise ValueError( + f"Invalid aggregator output shape: {(20, 3, n_nets)} -> {test.shape}" + ) + + def _init_reduction(self, reduction, hidden_size): + reduction_funcs = DeepONet._symbol_functions(dim=2) + if reduction in reduction_funcs: + reduction_func = reduction_funcs[reduction] + elif isinstance(reduction, nn.Module) or is_function(reduction): + reduction_func = reduction + elif reduction == "linear": + reduction_func = nn.Linear(hidden_size, len(self.output_variables)) + else: + raise ValueError(f"Unsupported reduction: {reduction}") + + self._reduction = reduction_func + logging.info("Selected reduction: %s", str(reduction)) + + # test the reduction + test = self._reduction(torch.ones((20, 3, hidden_size))) + if test.ndim < 2 or tuple(test.shape)[:2] != (20, 3): + msg = f"Invalid reduction output shape: {(20, 3, hidden_size)} -> {test.shape}" + raise ValueError(msg) + + @staticmethod + def _all_nets_same_output_layer_size(nets): + size = nets[0].layers[-1].out_features + return all((net.layers[-1].out_features == size for net in nets[1:])) @property def input_variables(self): """The input variables of the model""" - return self.trunk_net.input_variables + self.branch_net.input_variables + nets_input_variables = map(lambda net: net.input_variables, self._nets) + return reduce(sum, nets_input_variables) def forward(self, x): """ @@ -89,15 +199,20 @@ class DeepONet(torch.nn.Module): :rtype: LabelTensor """ - branch_output = self.branch_net( - x.extract(self.branch_net.input_variables)) + nets_outputs = tuple( + net(x.extract(net.input_variables)) for net in self._nets + ) + # torch.dstack(nets_outputs): (batch_size, net_output_size, n_nets) + aggregated = self._aggregator(torch.dstack(nets_outputs)) + # net_output_size = output_variables * hidden_size + aggregated_reshaped = aggregated.view( + (len(x), len(self.output_variables), -1) + ) + output_ = self._reduction(aggregated_reshaped) + output_ = torch.squeeze(torch.atleast_3d(output_), dim=2) - trunk_output = self.trunk_net( - x.extract(self.trunk_net.input_variables)) - - output_ = self.reduction(trunk_output * branch_output) + assert output_.shape == (len(x), len(self.output_variables)) output_ = output_.as_subclass(LabelTensor) output_.labels = self.output_variables - return output_ diff --git a/pina/model/feed_forward.py b/pina/model/feed_forward.py index ac5854c..74a8e01 100644 --- a/pina/model/feed_forward.py +++ b/pina/model/feed_forward.py @@ -26,9 +26,11 @@ class FeedForward(torch.nn.Module): `inner_size` are not considered. :param iterable(torch.nn.Module) extra_features: the additional input features to use ad augmented input. + :param bool bias: If `True` the MLP will consider some bias. """ def __init__(self, input_variables, output_variables, inner_size=20, - n_layers=2, func=nn.Tanh, layers=None, extra_features=None): + n_layers=2, func=nn.Tanh, layers=None, extra_features=None, + bias=True): """ """ super().__init__() @@ -62,7 +64,9 @@ class FeedForward(torch.nn.Module): self.layers = [] for i in range(len(tmp_layers)-1): - self.layers.append(nn.Linear(tmp_layers[i], tmp_layers[i+1])) + self.layers.append( + nn.Linear(tmp_layers[i], tmp_layers[i + 1], bias=bias) + ) if isinstance(func, list): self.functions = func @@ -94,7 +98,7 @@ class FeedForward(torch.nn.Module): if self.input_variables: x = x.extract(self.input_variables) - for i, feature in enumerate(self.extra_features): + for feature in self.extra_features: x = x.append(feature(x)) output = self.model(x).as_subclass(LabelTensor) diff --git a/pina/utils.py b/pina/utils.py index b2961c4..3f179e2 100644 --- a/pina/utils.py +++ b/pina/utils.py @@ -1,5 +1,7 @@ """Utils module""" from functools import reduce +import types + import torch from torch.utils.data import DataLoader, default_collate, ConcatDataset @@ -85,6 +87,17 @@ def torch_lhs(n, dim): return samples +def is_function(f): + """ + Checks whether the given object `f` is a function or lambda. + + :param object f: The object to be checked. + :return: `True` if `f` is a function, `False` otherwise. + :rtype: bool + """ + return type(f) == types.FunctionType or type(f) == types.LambdaType + + class PinaDataset(): def __init__(self, pinn) -> None: @@ -144,4 +157,4 @@ class PinaDataset(): return {self._location: tensor} def __len__(self): - return self._len \ No newline at end of file + return self._len diff --git a/tests/test_deeponet.py b/tests/test_deeponet.py index 86ac518..2ffefa7 100644 --- a/tests/test_deeponet.py +++ b/tests/test_deeponet.py @@ -1,9 +1,9 @@ -import torch import pytest +import torch from pina import LabelTensor -from pina.model import DeepONet, FeedForward as FFN - +from pina.model import DeepONet +from pina.model import FeedForward as FFN data = torch.rand((20, 3)) input_vars = ['a', 'b', 'c'] @@ -14,19 +14,17 @@ input_ = LabelTensor(data, input_vars) def test_constructor(): branch = FFN(input_variables=['a', 'c'], output_variables=20) trunk = FFN(input_variables=['b'], output_variables=20) - onet = DeepONet(trunk_net=trunk, branch_net=branch, - output_variables=output_vars) + onet = DeepONet(nets=[trunk, branch], output_variables=output_vars) def test_constructor_fails_when_invalid_inner_layer_size(): branch = FFN(input_variables=['a', 'c'], output_variables=20) trunk = FFN(input_variables=['b'], output_variables=19) with pytest.raises(ValueError): - DeepONet(trunk_net=trunk, branch_net=branch, output_variables=output_vars) + DeepONet(nets=[trunk, branch], output_variables=output_vars) def test_forward(): branch = FFN(input_variables=['a', 'c'], output_variables=10) trunk = FFN(input_variables=['b'], output_variables=10) - onet = DeepONet(trunk_net=trunk, branch_net=branch, - output_variables=output_vars) + onet = DeepONet(nets=[trunk, branch], output_variables=output_vars) output_ = onet(input_) assert output_.labels == output_vars