Generic DeepONet (#68)

* generic deeponet
* example for generic deeponet
* adapt tests to new interface
This commit is contained in:
Francesco Andreuzzi
2023-01-11 12:07:19 +01:00
committed by GitHub
parent e227700fbc
commit 7ce080fd85
5 changed files with 280 additions and 37 deletions

View File

@@ -0,0 +1,113 @@
import argparse
import logging
import torch
from problems.poisson import Poisson
from pina import PINN, LabelTensor, Plotter
from pina.model.deeponet import DeepONet, check_combos, spawn_combo_networks
logging.basicConfig(
filename="poisson_deeponet.log", filemode="w", level=logging.INFO
)
class SinFeature(torch.nn.Module):
"""
Feature: sin(x)
"""
def __init__(self, label):
super().__init__()
if not isinstance(label, (tuple, list)):
label = [label]
self._label = label
def forward(self, x):
"""
Defines the computation performed at every call.
:param LabelTensor x: the input tensor.
:return: the output computed by the model.
:rtype: LabelTensor
"""
t = torch.sin(x.extract(self._label) * torch.pi)
return LabelTensor(t, [f"sin({self._label})"])
def prepare_deeponet_model(args, problem, extra_feature_combo_func=None):
combos = tuple(map(lambda combo: combo.split("-"), args.combos.split(",")))
check_combos(combos, problem.input_variables)
extra_feature = extra_feature_combo_func if args.extra else None
networks = spawn_combo_networks(
combos=combos,
layers=list(map(int, args.layers.split(","))) if args.layers else [],
output_dimension=args.hidden * len(problem.output_variables),
func=torch.nn.Softplus,
extra_feature=extra_feature,
bias=not args.nobias,
)
return DeepONet(
networks,
problem.output_variables,
aggregator=args.aggregator,
reduction=args.reduction,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run PINA")
parser.add_argument("-s", "--save", action="store_true")
parser.add_argument("-l", "--load", action="store_true")
parser.add_argument("id_run", help="Run ID", type=int)
parser.add_argument("--extra", help="Extra features", action="store_true")
parser.add_argument("--nobias", action="store_true")
parser.add_argument(
"--combos",
help="DeepONet internal network combinations",
type=str,
required=True,
)
parser.add_argument(
"--aggregator", help="Aggregator for DeepONet", type=str, default="*"
)
parser.add_argument(
"--reduction", help="Reduction for DeepONet", type=str, default="+"
)
parser.add_argument(
"--hidden",
help="Number of variables in the hidden DeepONet layer",
type=int,
required=True,
)
parser.add_argument(
"--layers",
help="Structure of the DeepONet partial layers",
type=str,
required=True,
)
cli_args = parser.parse_args()
poisson_problem = Poisson()
model = prepare_deeponet_model(
cli_args,
poisson_problem,
extra_feature_combo_func=lambda combo: [SinFeature(combo)],
)
pinn = PINN(poisson_problem, model, lr=0.01, regularizer=1e-8)
if cli_args.save:
pinn.span_pts(
20, "grid", locations=["gamma1", "gamma2", "gamma3", "gamma4"]
)
pinn.span_pts(20, "grid", locations=["D"])
pinn.train(1.0e-10, 100)
pinn.save_state(f"pina.poisson_{cli_args.id_run}")
if cli_args.load:
pinn.load_state(f"pina.poisson_{cli_args.id_run}")
plotter = Plotter()
plotter.plot(pinn)

View File

@@ -1,8 +1,60 @@
"""Module for DeepONet model""" """Module for DeepONet model"""
import logging
from functools import partial, reduce
import torch import torch
import torch.nn as nn import torch.nn as nn
from pina import LabelTensor from pina import LabelTensor
from pina.model import FeedForward
from pina.utils import is_function
def check_combos(combos, variables):
"""
Check that the given combinations are subsets (overlapping
is allowed) of the given set of variables.
:param iterable(iterable(str)) combos: Combinations of variables.
:param iterable(str) variables: Variables.
"""
for combo in combos:
for variable in combo:
if variable not in variables:
raise ValueError(
f"Combinations should be (overlapping) subsets of input variables, {variable} is not an input variable"
)
def spawn_combo_networks(
combos, layers, output_dimension, func, extra_feature, bias=True
):
"""
Spawn internal networks for DeepONet based on the given combos.
:param iterable(iterable(str)) combos: Combinations of variables.
:param iterable(int) layers: Size of hidden layers.
:param int output_dimension: Size of the output layer of the networks.
:param func: Nonlinearity.
:param extra_feature: Extra feature to be considered by the networks.
:param bool bias: Whether to consider bias or not.
"""
if is_function(extra_feature):
extra_feature_func = lambda _: extra_feature
else:
extra_feature_func = extra_feature
return [
FeedForward(
layers=layers,
input_variables=tuple(combo),
output_variables=output_dimension,
func=func,
extra_features=extra_feature_func(combo),
bias=bias,
)
for combo in combos
]
class DeepONet(torch.nn.Module): class DeepONet(torch.nn.Module):
@@ -18,23 +70,27 @@ class DeepONet(torch.nn.Module):
<https://doi.org/10.1038/s42256-021-00302-5>`_ <https://doi.org/10.1038/s42256-021-00302-5>`_
""" """
def __init__(self, branch_net, trunk_net, output_variables):
def __init__(self, nets, output_variables, aggregator="*", reduction="+"):
""" """
:param torch.nn.Module branch_net: the neural network to use as branch :param iterable(torch.nn.Module) nets: Internal DeepONet networks
model. It has to take as input a :class:`LabelTensor`. The number (branch and trunk in the original DeepONet).
of dimension of the output has to be the same of the `trunk_net`.
:param torch.nn.Module trunk_net: the neural network to use as trunk
model. It has to take as input a :class:`LabelTensor`. The number
of dimension of the output has to be the same of the `branch_net`.
:param list(str) output_variables: the list containing the labels :param list(str) output_variables: the list containing the labels
corresponding to the components of the output computed by the corresponding to the components of the output computed by the
model. model.
:param string | callable aggregator: Aggregator to be used to aggregate
partial results from the modules in `nets`. Partial results are
aggregated component-wise. See :func:`_symbol_functions` for the
available default aggregators.
:param string | callable reduction: Reduction to be used to reduce
the aggregated result of the modules in `nets` to the desired output
dimension. See :func:`_symbol_functions` for the available default
reductions.
:Example: :Example:
>>> branch = FFN(input_variables=['a', 'c'], output_variables=20) >>> branch = FFN(input_variables=['a', 'c'], output_variables=20)
>>> trunk = FFN(input_variables=['b'], output_variables=20) >>> trunk = FFN(input_variables=['b'], output_variables=20)
>>> onet = DeepONet(trunk_net=trunk, branch_net=branch >>> onet = DeepONet(nets=[trunk, branch], output_variables=output_vars)
>>> output_variables=output_vars)
DeepONet( DeepONet(
(trunk_net): FeedForward( (trunk_net): FeedForward(
(extra_features): Sequential() (extra_features): Sequential()
@@ -63,22 +119,76 @@ class DeepONet(torch.nn.Module):
self.output_variables = output_variables self.output_variables = output_variables
self.output_dimension = len(output_variables) self.output_dimension = len(output_variables)
trunk_out_dim = trunk_net.layers[-1].out_features self._init_aggregator(aggregator, n_nets=len(nets))
branch_out_dim = branch_net.layers[-1].out_features hidden_size = nets[0].model[-1].out_features
self._init_reduction(reduction, hidden_size=hidden_size)
if trunk_out_dim != branch_out_dim: if not DeepONet._all_nets_same_output_layer_size(nets):
raise ValueError('Branch and trunk networks have not the same ' raise ValueError("All networks should have the same output size")
'output dimension.') self._nets = torch.nn.ModuleList(nets)
logging.info("Combo DeepONet children: %s", list(self.children()))
self.trunk_net = trunk_net @staticmethod
self.branch_net = branch_net def _symbol_functions(**kwargs):
return {
"+": partial(torch.sum, **kwargs),
"*": partial(torch.prod, **kwargs),
"mean": partial(torch.mean, **kwargs),
"min": lambda x: torch.min(x, **kwargs).values,
"max": lambda x: torch.max(x, **kwargs).values,
}
self.reduction = nn.Linear(trunk_out_dim, self.output_dimension) def _init_aggregator(self, aggregator, n_nets):
aggregator_funcs = DeepONet._symbol_functions(dim=2)
if aggregator in aggregator_funcs:
aggregator_func = aggregator_funcs[aggregator]
elif isinstance(aggregator, nn.Module) or is_function(aggregator):
aggregator_func = aggregator
elif aggregator == "linear":
aggregator_func = nn.Linear(n_nets, len(self.output_variables))
else:
raise ValueError(f"Unsupported aggregation: {str(aggregator)}")
self._aggregator = aggregator_func
logging.info("Selected aggregator: %s", str(aggregator_func))
# test the aggregator
test = self._aggregator(torch.ones((20, 3, n_nets)))
if test.ndim < 2 or tuple(test.shape)[:2] != (20, 3):
raise ValueError(
f"Invalid aggregator output shape: {(20, 3, n_nets)} -> {test.shape}"
)
def _init_reduction(self, reduction, hidden_size):
reduction_funcs = DeepONet._symbol_functions(dim=2)
if reduction in reduction_funcs:
reduction_func = reduction_funcs[reduction]
elif isinstance(reduction, nn.Module) or is_function(reduction):
reduction_func = reduction
elif reduction == "linear":
reduction_func = nn.Linear(hidden_size, len(self.output_variables))
else:
raise ValueError(f"Unsupported reduction: {reduction}")
self._reduction = reduction_func
logging.info("Selected reduction: %s", str(reduction))
# test the reduction
test = self._reduction(torch.ones((20, 3, hidden_size)))
if test.ndim < 2 or tuple(test.shape)[:2] != (20, 3):
msg = f"Invalid reduction output shape: {(20, 3, hidden_size)} -> {test.shape}"
raise ValueError(msg)
@staticmethod
def _all_nets_same_output_layer_size(nets):
size = nets[0].layers[-1].out_features
return all((net.layers[-1].out_features == size for net in nets[1:]))
@property @property
def input_variables(self): def input_variables(self):
"""The input variables of the model""" """The input variables of the model"""
return self.trunk_net.input_variables + self.branch_net.input_variables nets_input_variables = map(lambda net: net.input_variables, self._nets)
return reduce(sum, nets_input_variables)
def forward(self, x): def forward(self, x):
""" """
@@ -89,15 +199,20 @@ class DeepONet(torch.nn.Module):
:rtype: LabelTensor :rtype: LabelTensor
""" """
branch_output = self.branch_net( nets_outputs = tuple(
x.extract(self.branch_net.input_variables)) net(x.extract(net.input_variables)) for net in self._nets
)
# torch.dstack(nets_outputs): (batch_size, net_output_size, n_nets)
aggregated = self._aggregator(torch.dstack(nets_outputs))
# net_output_size = output_variables * hidden_size
aggregated_reshaped = aggregated.view(
(len(x), len(self.output_variables), -1)
)
output_ = self._reduction(aggregated_reshaped)
output_ = torch.squeeze(torch.atleast_3d(output_), dim=2)
trunk_output = self.trunk_net( assert output_.shape == (len(x), len(self.output_variables))
x.extract(self.trunk_net.input_variables))
output_ = self.reduction(trunk_output * branch_output)
output_ = output_.as_subclass(LabelTensor) output_ = output_.as_subclass(LabelTensor)
output_.labels = self.output_variables output_.labels = self.output_variables
return output_ return output_

View File

@@ -26,9 +26,11 @@ class FeedForward(torch.nn.Module):
`inner_size` are not considered. `inner_size` are not considered.
:param iterable(torch.nn.Module) extra_features: the additional input :param iterable(torch.nn.Module) extra_features: the additional input
features to use ad augmented input. features to use ad augmented input.
:param bool bias: If `True` the MLP will consider some bias.
""" """
def __init__(self, input_variables, output_variables, inner_size=20, def __init__(self, input_variables, output_variables, inner_size=20,
n_layers=2, func=nn.Tanh, layers=None, extra_features=None): n_layers=2, func=nn.Tanh, layers=None, extra_features=None,
bias=True):
""" """
""" """
super().__init__() super().__init__()
@@ -62,7 +64,9 @@ class FeedForward(torch.nn.Module):
self.layers = [] self.layers = []
for i in range(len(tmp_layers)-1): for i in range(len(tmp_layers)-1):
self.layers.append(nn.Linear(tmp_layers[i], tmp_layers[i+1])) self.layers.append(
nn.Linear(tmp_layers[i], tmp_layers[i + 1], bias=bias)
)
if isinstance(func, list): if isinstance(func, list):
self.functions = func self.functions = func
@@ -94,7 +98,7 @@ class FeedForward(torch.nn.Module):
if self.input_variables: if self.input_variables:
x = x.extract(self.input_variables) x = x.extract(self.input_variables)
for i, feature in enumerate(self.extra_features): for feature in self.extra_features:
x = x.append(feature(x)) x = x.append(feature(x))
output = self.model(x).as_subclass(LabelTensor) output = self.model(x).as_subclass(LabelTensor)

View File

@@ -1,5 +1,7 @@
"""Utils module""" """Utils module"""
from functools import reduce from functools import reduce
import types
import torch import torch
from torch.utils.data import DataLoader, default_collate, ConcatDataset from torch.utils.data import DataLoader, default_collate, ConcatDataset
@@ -85,6 +87,17 @@ def torch_lhs(n, dim):
return samples return samples
def is_function(f):
"""
Checks whether the given object `f` is a function or lambda.
:param object f: The object to be checked.
:return: `True` if `f` is a function, `False` otherwise.
:rtype: bool
"""
return type(f) == types.FunctionType or type(f) == types.LambdaType
class PinaDataset(): class PinaDataset():
def __init__(self, pinn) -> None: def __init__(self, pinn) -> None:
@@ -144,4 +157,4 @@ class PinaDataset():
return {self._location: tensor} return {self._location: tensor}
def __len__(self): def __len__(self):
return self._len return self._len

View File

@@ -1,9 +1,9 @@
import torch
import pytest import pytest
import torch
from pina import LabelTensor from pina import LabelTensor
from pina.model import DeepONet, FeedForward as FFN from pina.model import DeepONet
from pina.model import FeedForward as FFN
data = torch.rand((20, 3)) data = torch.rand((20, 3))
input_vars = ['a', 'b', 'c'] input_vars = ['a', 'b', 'c']
@@ -14,19 +14,17 @@ input_ = LabelTensor(data, input_vars)
def test_constructor(): def test_constructor():
branch = FFN(input_variables=['a', 'c'], output_variables=20) branch = FFN(input_variables=['a', 'c'], output_variables=20)
trunk = FFN(input_variables=['b'], output_variables=20) trunk = FFN(input_variables=['b'], output_variables=20)
onet = DeepONet(trunk_net=trunk, branch_net=branch, onet = DeepONet(nets=[trunk, branch], output_variables=output_vars)
output_variables=output_vars)
def test_constructor_fails_when_invalid_inner_layer_size(): def test_constructor_fails_when_invalid_inner_layer_size():
branch = FFN(input_variables=['a', 'c'], output_variables=20) branch = FFN(input_variables=['a', 'c'], output_variables=20)
trunk = FFN(input_variables=['b'], output_variables=19) trunk = FFN(input_variables=['b'], output_variables=19)
with pytest.raises(ValueError): with pytest.raises(ValueError):
DeepONet(trunk_net=trunk, branch_net=branch, output_variables=output_vars) DeepONet(nets=[trunk, branch], output_variables=output_vars)
def test_forward(): def test_forward():
branch = FFN(input_variables=['a', 'c'], output_variables=10) branch = FFN(input_variables=['a', 'c'], output_variables=10)
trunk = FFN(input_variables=['b'], output_variables=10) trunk = FFN(input_variables=['b'], output_variables=10)
onet = DeepONet(trunk_net=trunk, branch_net=branch, onet = DeepONet(nets=[trunk, branch], output_variables=output_vars)
output_variables=output_vars)
output_ = onet(input_) output_ = onet(input_)
assert output_.labels == output_vars assert output_.labels == output_vars