Generic DeepONet (#68)

* generic deeponet
* example for generic deeponet
* adapt tests to new interface
Francesco Andreuzzi
2023-01-11 12:07:19 +01:00
committed by GitHub
parent e227700fbc
commit 7ce080fd85
5 changed files with 280 additions and 37 deletions


@@ -1,8 +1,60 @@
"""Module for DeepONet model"""
import logging
from functools import partial, reduce
import torch
import torch.nn as nn
from pina import LabelTensor
from pina.model import FeedForward
from pina.utils import is_function
def check_combos(combos, variables):
"""
Check that the given combinations are subsets (overlapping
is allowed) of the given set of variables.
:param iterable(iterable(str)) combos: Combinations of variables.
:param iterable(str) variables: Variables.
"""
for combo in combos:
for variable in combo:
if variable not in variables:
raise ValueError(
f"Combinations should be (overlapping) subsets of input variables, {variable} is not an input variable"
)
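A brief usage sketch of check_combos (the variable names here are illustrative, not taken from the PR): overlapping subsets pass, an unknown variable raises.

variables = ["x", "y", "mu"]
check_combos([["x", "y"], ["y", "mu"]], variables)  # passes: overlapping subsets are allowed
check_combos([["x", "t"]], variables)               # ValueError: t is not an input variable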
def spawn_combo_networks(
combos, layers, output_dimension, func, extra_feature, bias=True
):
"""
Spawn internal networks for DeepONet based on the given combos.
:param iterable(iterable(str)) combos: Combinations of variables.
:param iterable(int) layers: Size of hidden layers.
:param int output_dimension: Size of the output layer of the networks.
:param func: Nonlinearity.
:param extra_feature: Extra feature to be considered by the networks.
:param bool bias: Whether the networks include a bias term.
"""
if is_function(extra_feature):
extra_feature_func = lambda _: extra_feature
else:
extra_feature_func = extra_feature
return [
FeedForward(
layers=layers,
input_variables=tuple(combo),
output_variables=output_dimension,
func=func,
extra_features=extra_feature_func(combo),
bias=bias,
)
for combo in combos
]
class DeepONet(torch.nn.Module):
@@ -18,23 +70,27 @@ class DeepONet(torch.nn.Module):
<https://doi.org/10.1038/s42256-021-00302-5>`_
"""
def __init__(self, branch_net, trunk_net, output_variables):
def __init__(self, nets, output_variables, aggregator="*", reduction="+"):
"""
:param torch.nn.Module branch_net: the neural network to use as branch
model. It has to take as input a :class:`LabelTensor`. The number
of dimension of the output has to be the same of the `trunk_net`.
:param torch.nn.Module trunk_net: the neural network to use as trunk
model. It has to take as input a :class:`LabelTensor`. The number
of dimension of the output has to be the same of the `branch_net`.
:param iterable(torch.nn.Module) nets: Internal DeepONet networks
(branch and trunk in the original DeepONet).
:param list(str) output_variables: the list containing the labels
corresponding to the components of the output computed by the
model.
:param string | callable aggregator: Aggregator to be used to aggregate
partial results from the modules in `nets`. Partial results are
aggregated component-wise. See :func:`_symbol_functions` for the
available default aggregators.
:param string | callable reduction: Reduction to be used to reduce
the aggregated result of the modules in `nets` to the desired output
dimension. See :func:`_symbol_functions` for the available default
reductions.
:Example:
>>> branch = FFN(input_variables=['a', 'c'], output_variables=20)
>>> trunk = FFN(input_variables=['b'], output_variables=20)
>>> onet = DeepONet(trunk_net=trunk, branch_net=branch
>>> output_variables=output_vars)
>>> onet = DeepONet(nets=[trunk, branch], output_variables=output_vars)
DeepONet(
(trunk_net): FeedForward(
(extra_features): Sequential()
@@ -63,22 +119,76 @@ class DeepONet(torch.nn.Module):
self.output_variables = output_variables
self.output_dimension = len(output_variables)
trunk_out_dim = trunk_net.layers[-1].out_features
branch_out_dim = branch_net.layers[-1].out_features
self._init_aggregator(aggregator, n_nets=len(nets))
hidden_size = nets[0].model[-1].out_features
self._init_reduction(reduction, hidden_size=hidden_size)
if trunk_out_dim != branch_out_dim:
raise ValueError('Branch and trunk networks have not the same '
'output dimension.')
if not DeepONet._all_nets_same_output_layer_size(nets):
raise ValueError("All networks should have the same output size")
self._nets = torch.nn.ModuleList(nets)
logging.info("Combo DeepONet children: %s", list(self.children()))
self.trunk_net = trunk_net
self.branch_net = branch_net
@staticmethod
def _symbol_functions(**kwargs):
return {
"+": partial(torch.sum, **kwargs),
"*": partial(torch.prod, **kwargs),
"mean": partial(torch.mean, **kwargs),
"min": lambda x: torch.min(x, **kwargs).values,
"max": lambda x: torch.max(x, **kwargs).values,
}
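For reference, a minimal torch-only sketch of what the default symbols compute on the stacked per-network outputs (shapes are illustrative): with two networks, the "*" entry, partial(torch.prod, dim=2), reproduces the branch-times-trunk product of the original DeepONet.

import torch

stacked = torch.dstack((torch.full((4, 8), 2.0), torch.full((4, 8), 3.0)))  # (batch, hidden, n_nets)
aggregated = torch.prod(stacked, dim=2)  # what the "*" entry applies
print(aggregated.shape, aggregated[0, 0].item())  # torch.Size([4, 8]) 6.0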
self.reduction = nn.Linear(trunk_out_dim, self.output_dimension)
def _init_aggregator(self, aggregator, n_nets):
aggregator_funcs = DeepONet._symbol_functions(dim=2)
if aggregator in aggregator_funcs:
aggregator_func = aggregator_funcs[aggregator]
elif isinstance(aggregator, nn.Module) or is_function(aggregator):
aggregator_func = aggregator
elif aggregator == "linear":
aggregator_func = nn.Linear(n_nets, len(self.output_variables))
else:
raise ValueError(f"Unsupported aggregation: {str(aggregator)}")
self._aggregator = aggregator_func
logging.info("Selected aggregator: %s", str(aggregator_func))
# test the aggregator
test = self._aggregator(torch.ones((20, 3, n_nets)))
if test.ndim < 2 or tuple(test.shape)[:2] != (20, 3):
raise ValueError(
f"Invalid aggregator output shape: {(20, 3, n_nets)} -> {test.shape}"
)
def _init_reduction(self, reduction, hidden_size):
reduction_funcs = DeepONet._symbol_functions(dim=2)
if reduction in reduction_funcs:
reduction_func = reduction_funcs[reduction]
elif isinstance(reduction, nn.Module) or is_function(reduction):
reduction_func = reduction
elif reduction == "linear":
reduction_func = nn.Linear(hidden_size, len(self.output_variables))
else:
raise ValueError(f"Unsupported reduction: {reduction}")
self._reduction = reduction_func
logging.info("Selected reduction: %s", str(reduction))
# test the reduction
test = self._reduction(torch.ones((20, 3, hidden_size)))
if test.ndim < 2 or tuple(test.shape)[:2] != (20, 3):
msg = f"Invalid reduction output shape: {(20, 3, hidden_size)} -> {test.shape}"
raise ValueError(msg)
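A hedged sketch of the contract these probes enforce on custom callables: an aggregator maps (batch, net_output_size, n_nets) to a tensor whose first two dimensions are unchanged, and a reduction does the same for (batch, n_outputs, hidden_size). The hypothetical lambdas below mirror the (20, 3, ...) test tensors used above.

import torch

aggregator = lambda stacked: stacked.mean(dim=2)  # hypothetical custom aggregator
reduction = lambda grouped: grouped.norm(dim=2)   # hypothetical custom reduction

assert aggregator(torch.ones(20, 3, 4)).shape == (20, 3)
assert reduction(torch.ones(20, 3, 16)).shape == (20, 3)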
@staticmethod
def _all_nets_same_output_layer_size(nets):
size = nets[0].layers[-1].out_features
return all((net.layers[-1].out_features == size for net in nets[1:]))
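A short sketch of the failure path guarded by this check (names and sizes are illustrative, and the imports assume the public pina.model API):

from pina.model import DeepONet, FeedForward

branch = FeedForward(input_variables=["a"], output_variables=20)
trunk = FeedForward(input_variables=["b"], output_variables=10)
DeepONet(nets=[trunk, branch], output_variables=["u"])
# ValueError: All networks should have the same output size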
@property
def input_variables(self):
"""The input variables of the model"""
return self.trunk_net.input_variables + self.branch_net.input_variables
nets_input_variables = map(lambda net: net.input_variables, self._nets)
return reduce(sum, nets_input_variables)
def forward(self, x):
"""
@@ -89,15 +199,20 @@ class DeepONet(torch.nn.Module):
:rtype: LabelTensor
"""
branch_output = self.branch_net(
x.extract(self.branch_net.input_variables))
nets_outputs = tuple(
net(x.extract(net.input_variables)) for net in self._nets
)
# torch.dstack(nets_outputs): (batch_size, net_output_size, n_nets)
aggregated = self._aggregator(torch.dstack(nets_outputs))
# net_output_size = output_variables * hidden_size
aggregated_reshaped = aggregated.view(
(len(x), len(self.output_variables), -1)
)
output_ = self._reduction(aggregated_reshaped)
output_ = torch.squeeze(torch.atleast_3d(output_), dim=2)
trunk_output = self.trunk_net(
x.extract(self.trunk_net.input_variables))
output_ = self.reduction(trunk_output * branch_output)
assert output_.shape == (len(x), len(self.output_variables))
output_ = output_.as_subclass(LabelTensor)
output_.labels = self.output_variables
return output_
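An end-to-end sketch mirroring the docstring example above (the labels "a", "b", "c", "u" and the batch size are illustrative; the imports assume FeedForward, DeepONet and LabelTensor are exposed as in this PR):

import torch
from pina import LabelTensor
from pina.model import DeepONet, FeedForward

branch = FeedForward(input_variables=["a", "c"], output_variables=20)
trunk = FeedForward(input_variables=["b"], output_variables=20)
onet = DeepONet(nets=[trunk, branch], output_variables=["u"],
                aggregator="*", reduction="+")

x = LabelTensor(torch.rand(16, 3), ["a", "b", "c"])
y = onet(x)
print(y.shape)  # torch.Size([16, 1]), labelled "u"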


@@ -26,9 +26,11 @@ class FeedForward(torch.nn.Module):
`inner_size` are not considered.
:param iterable(torch.nn.Module) extra_features: the additional input
features to use as augmented input.
:param bool bias: If `True`, the linear layers of the MLP include a bias term.
"""
def __init__(self, input_variables, output_variables, inner_size=20,
n_layers=2, func=nn.Tanh, layers=None, extra_features=None):
n_layers=2, func=nn.Tanh, layers=None, extra_features=None,
bias=True):
"""
"""
super().__init__()
@@ -62,7 +64,9 @@ class FeedForward(torch.nn.Module):
self.layers = []
for i in range(len(tmp_layers)-1):
self.layers.append(nn.Linear(tmp_layers[i], tmp_layers[i+1]))
self.layers.append(
nn.Linear(tmp_layers[i], tmp_layers[i + 1], bias=bias)
)
if isinstance(func, list):
self.functions = func
@@ -94,7 +98,7 @@ class FeedForward(torch.nn.Module):
if self.input_variables:
x = x.extract(self.input_variables)
for i, feature in enumerate(self.extra_features):
for feature in self.extra_features:
x = x.append(feature(x))
output = self.model(x).as_subclass(LabelTensor)
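A quick sketch of the new bias flag (names are illustrative): with bias=False every nn.Linear built in the loop above is created without a bias term.

import torch
from pina.model import FeedForward

net = FeedForward(input_variables=["x", "y"], output_variables=1, bias=False)
assert all(layer.bias is None
           for layer in net.model if isinstance(layer, torch.nn.Linear))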