Generic DeepONet (#68)
* generic deeponet * example for generic deeponet * adapt tests to new interface
This commit is contained in:
committed by
GitHub
parent
e227700fbc
commit
7ce080fd85
113
examples/run_poisson_deeponet.py
Normal file
113
examples/run_poisson_deeponet.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from problems.poisson import Poisson
|
||||||
|
|
||||||
|
from pina import PINN, LabelTensor, Plotter
|
||||||
|
from pina.model.deeponet import DeepONet, check_combos, spawn_combo_networks
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
filename="poisson_deeponet.log", filemode="w", level=logging.INFO
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SinFeature(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Feature: sin(x)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, label):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if not isinstance(label, (tuple, list)):
|
||||||
|
label = [label]
|
||||||
|
self._label = label
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Defines the computation performed at every call.
|
||||||
|
|
||||||
|
:param LabelTensor x: the input tensor.
|
||||||
|
:return: the output computed by the model.
|
||||||
|
:rtype: LabelTensor
|
||||||
|
"""
|
||||||
|
t = torch.sin(x.extract(self._label) * torch.pi)
|
||||||
|
return LabelTensor(t, [f"sin({self._label})"])
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_deeponet_model(args, problem, extra_feature_combo_func=None):
|
||||||
|
combos = tuple(map(lambda combo: combo.split("-"), args.combos.split(",")))
|
||||||
|
check_combos(combos, problem.input_variables)
|
||||||
|
|
||||||
|
extra_feature = extra_feature_combo_func if args.extra else None
|
||||||
|
networks = spawn_combo_networks(
|
||||||
|
combos=combos,
|
||||||
|
layers=list(map(int, args.layers.split(","))) if args.layers else [],
|
||||||
|
output_dimension=args.hidden * len(problem.output_variables),
|
||||||
|
func=torch.nn.Softplus,
|
||||||
|
extra_feature=extra_feature,
|
||||||
|
bias=not args.nobias,
|
||||||
|
)
|
||||||
|
|
||||||
|
return DeepONet(
|
||||||
|
networks,
|
||||||
|
problem.output_variables,
|
||||||
|
aggregator=args.aggregator,
|
||||||
|
reduction=args.reduction,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Run PINA")
|
||||||
|
parser.add_argument("-s", "--save", action="store_true")
|
||||||
|
parser.add_argument("-l", "--load", action="store_true")
|
||||||
|
parser.add_argument("id_run", help="Run ID", type=int)
|
||||||
|
|
||||||
|
parser.add_argument("--extra", help="Extra features", action="store_true")
|
||||||
|
parser.add_argument("--nobias", action="store_true")
|
||||||
|
parser.add_argument(
|
||||||
|
"--combos",
|
||||||
|
help="DeepONet internal network combinations",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--aggregator", help="Aggregator for DeepONet", type=str, default="*"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--reduction", help="Reduction for DeepONet", type=str, default="+"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hidden",
|
||||||
|
help="Number of variables in the hidden DeepONet layer",
|
||||||
|
type=int,
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--layers",
|
||||||
|
help="Structure of the DeepONet partial layers",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
cli_args = parser.parse_args()
|
||||||
|
|
||||||
|
poisson_problem = Poisson()
|
||||||
|
|
||||||
|
model = prepare_deeponet_model(
|
||||||
|
cli_args,
|
||||||
|
poisson_problem,
|
||||||
|
extra_feature_combo_func=lambda combo: [SinFeature(combo)],
|
||||||
|
)
|
||||||
|
pinn = PINN(poisson_problem, model, lr=0.01, regularizer=1e-8)
|
||||||
|
if cli_args.save:
|
||||||
|
pinn.span_pts(
|
||||||
|
20, "grid", locations=["gamma1", "gamma2", "gamma3", "gamma4"]
|
||||||
|
)
|
||||||
|
pinn.span_pts(20, "grid", locations=["D"])
|
||||||
|
pinn.train(1.0e-10, 100)
|
||||||
|
pinn.save_state(f"pina.poisson_{cli_args.id_run}")
|
||||||
|
if cli_args.load:
|
||||||
|
pinn.load_state(f"pina.poisson_{cli_args.id_run}")
|
||||||
|
plotter = Plotter()
|
||||||
|
plotter.plot(pinn)
|
||||||
@@ -1,8 +1,60 @@
|
|||||||
"""Module for DeepONet model"""
|
"""Module for DeepONet model"""
|
||||||
|
import logging
|
||||||
|
from functools import partial, reduce
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from pina import LabelTensor
|
from pina import LabelTensor
|
||||||
|
from pina.model import FeedForward
|
||||||
|
from pina.utils import is_function
|
||||||
|
|
||||||
|
|
||||||
|
def check_combos(combos, variables):
|
||||||
|
"""
|
||||||
|
Check that the given combinations are subsets (overlapping
|
||||||
|
is allowed) of the given set of variables.
|
||||||
|
|
||||||
|
:param iterable(iterable(str)) combos: Combinations of variables.
|
||||||
|
:param iterable(str) variables: Variables.
|
||||||
|
"""
|
||||||
|
for combo in combos:
|
||||||
|
for variable in combo:
|
||||||
|
if variable not in variables:
|
||||||
|
raise ValueError(
|
||||||
|
f"Combinations should be (overlapping) subsets of input variables, {variable} is not an input variable"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def spawn_combo_networks(
|
||||||
|
combos, layers, output_dimension, func, extra_feature, bias=True
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Spawn internal networks for DeepONet based on the given combos.
|
||||||
|
|
||||||
|
:param iterable(iterable(str)) combos: Combinations of variables.
|
||||||
|
:param iterable(int) layers: Size of hidden layers.
|
||||||
|
:param int output_dimension: Size of the output layer of the networks.
|
||||||
|
:param func: Nonlinearity.
|
||||||
|
:param extra_feature: Extra feature to be considered by the networks.
|
||||||
|
:param bool bias: Whether to consider bias or not.
|
||||||
|
"""
|
||||||
|
if is_function(extra_feature):
|
||||||
|
extra_feature_func = lambda _: extra_feature
|
||||||
|
else:
|
||||||
|
extra_feature_func = extra_feature
|
||||||
|
|
||||||
|
return [
|
||||||
|
FeedForward(
|
||||||
|
layers=layers,
|
||||||
|
input_variables=tuple(combo),
|
||||||
|
output_variables=output_dimension,
|
||||||
|
func=func,
|
||||||
|
extra_features=extra_feature_func(combo),
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
for combo in combos
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class DeepONet(torch.nn.Module):
|
class DeepONet(torch.nn.Module):
|
||||||
@@ -18,23 +70,27 @@ class DeepONet(torch.nn.Module):
|
|||||||
<https://doi.org/10.1038/s42256-021-00302-5>`_
|
<https://doi.org/10.1038/s42256-021-00302-5>`_
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def __init__(self, branch_net, trunk_net, output_variables):
|
|
||||||
|
def __init__(self, nets, output_variables, aggregator="*", reduction="+"):
|
||||||
"""
|
"""
|
||||||
:param torch.nn.Module branch_net: the neural network to use as branch
|
:param iterable(torch.nn.Module) nets: Internal DeepONet networks
|
||||||
model. It has to take as input a :class:`LabelTensor`. The number
|
(branch and trunk in the original DeepONet).
|
||||||
of dimension of the output has to be the same of the `trunk_net`.
|
|
||||||
:param torch.nn.Module trunk_net: the neural network to use as trunk
|
|
||||||
model. It has to take as input a :class:`LabelTensor`. The number
|
|
||||||
of dimension of the output has to be the same of the `branch_net`.
|
|
||||||
:param list(str) output_variables: the list containing the labels
|
:param list(str) output_variables: the list containing the labels
|
||||||
corresponding to the components of the output computed by the
|
corresponding to the components of the output computed by the
|
||||||
model.
|
model.
|
||||||
|
:param string | callable aggregator: Aggregator to be used to aggregate
|
||||||
|
partial results from the modules in `nets`. Partial results are
|
||||||
|
aggregated component-wise. See :func:`_symbol_functions` for the
|
||||||
|
available default aggregators.
|
||||||
|
:param string | callable reduction: Reduction to be used to reduce
|
||||||
|
the aggregated result of the modules in `nets` to the desired output
|
||||||
|
dimension. See :func:`_symbol_functions` for the available default
|
||||||
|
reductions.
|
||||||
|
|
||||||
:Example:
|
:Example:
|
||||||
>>> branch = FFN(input_variables=['a', 'c'], output_variables=20)
|
>>> branch = FFN(input_variables=['a', 'c'], output_variables=20)
|
||||||
>>> trunk = FFN(input_variables=['b'], output_variables=20)
|
>>> trunk = FFN(input_variables=['b'], output_variables=20)
|
||||||
>>> onet = DeepONet(trunk_net=trunk, branch_net=branch
|
>>> onet = DeepONet(nets=[trunk, branch], output_variables=output_vars)
|
||||||
>>> output_variables=output_vars)
|
|
||||||
DeepONet(
|
DeepONet(
|
||||||
(trunk_net): FeedForward(
|
(trunk_net): FeedForward(
|
||||||
(extra_features): Sequential()
|
(extra_features): Sequential()
|
||||||
@@ -63,22 +119,76 @@ class DeepONet(torch.nn.Module):
|
|||||||
self.output_variables = output_variables
|
self.output_variables = output_variables
|
||||||
self.output_dimension = len(output_variables)
|
self.output_dimension = len(output_variables)
|
||||||
|
|
||||||
trunk_out_dim = trunk_net.layers[-1].out_features
|
self._init_aggregator(aggregator, n_nets=len(nets))
|
||||||
branch_out_dim = branch_net.layers[-1].out_features
|
hidden_size = nets[0].model[-1].out_features
|
||||||
|
self._init_reduction(reduction, hidden_size=hidden_size)
|
||||||
|
|
||||||
if trunk_out_dim != branch_out_dim:
|
if not DeepONet._all_nets_same_output_layer_size(nets):
|
||||||
raise ValueError('Branch and trunk networks have not the same '
|
raise ValueError("All networks should have the same output size")
|
||||||
'output dimension.')
|
self._nets = torch.nn.ModuleList(nets)
|
||||||
|
logging.info("Combo DeepONet children: %s", list(self.children()))
|
||||||
|
|
||||||
self.trunk_net = trunk_net
|
@staticmethod
|
||||||
self.branch_net = branch_net
|
def _symbol_functions(**kwargs):
|
||||||
|
return {
|
||||||
|
"+": partial(torch.sum, **kwargs),
|
||||||
|
"*": partial(torch.prod, **kwargs),
|
||||||
|
"mean": partial(torch.mean, **kwargs),
|
||||||
|
"min": lambda x: torch.min(x, **kwargs).values,
|
||||||
|
"max": lambda x: torch.max(x, **kwargs).values,
|
||||||
|
}
|
||||||
|
|
||||||
self.reduction = nn.Linear(trunk_out_dim, self.output_dimension)
|
def _init_aggregator(self, aggregator, n_nets):
|
||||||
|
aggregator_funcs = DeepONet._symbol_functions(dim=2)
|
||||||
|
if aggregator in aggregator_funcs:
|
||||||
|
aggregator_func = aggregator_funcs[aggregator]
|
||||||
|
elif isinstance(aggregator, nn.Module) or is_function(aggregator):
|
||||||
|
aggregator_func = aggregator
|
||||||
|
elif aggregator == "linear":
|
||||||
|
aggregator_func = nn.Linear(n_nets, len(self.output_variables))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported aggregation: {str(aggregator)}")
|
||||||
|
|
||||||
|
self._aggregator = aggregator_func
|
||||||
|
logging.info("Selected aggregator: %s", str(aggregator_func))
|
||||||
|
|
||||||
|
# test the aggregator
|
||||||
|
test = self._aggregator(torch.ones((20, 3, n_nets)))
|
||||||
|
if test.ndim < 2 or tuple(test.shape)[:2] != (20, 3):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid aggregator output shape: {(20, 3, n_nets)} -> {test.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init_reduction(self, reduction, hidden_size):
|
||||||
|
reduction_funcs = DeepONet._symbol_functions(dim=2)
|
||||||
|
if reduction in reduction_funcs:
|
||||||
|
reduction_func = reduction_funcs[reduction]
|
||||||
|
elif isinstance(reduction, nn.Module) or is_function(reduction):
|
||||||
|
reduction_func = reduction
|
||||||
|
elif reduction == "linear":
|
||||||
|
reduction_func = nn.Linear(hidden_size, len(self.output_variables))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported reduction: {reduction}")
|
||||||
|
|
||||||
|
self._reduction = reduction_func
|
||||||
|
logging.info("Selected reduction: %s", str(reduction))
|
||||||
|
|
||||||
|
# test the reduction
|
||||||
|
test = self._reduction(torch.ones((20, 3, hidden_size)))
|
||||||
|
if test.ndim < 2 or tuple(test.shape)[:2] != (20, 3):
|
||||||
|
msg = f"Invalid reduction output shape: {(20, 3, hidden_size)} -> {test.shape}"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _all_nets_same_output_layer_size(nets):
|
||||||
|
size = nets[0].layers[-1].out_features
|
||||||
|
return all((net.layers[-1].out_features == size for net in nets[1:]))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_variables(self):
|
def input_variables(self):
|
||||||
"""The input variables of the model"""
|
"""The input variables of the model"""
|
||||||
return self.trunk_net.input_variables + self.branch_net.input_variables
|
nets_input_variables = map(lambda net: net.input_variables, self._nets)
|
||||||
|
return reduce(sum, nets_input_variables)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
"""
|
"""
|
||||||
@@ -89,15 +199,20 @@ class DeepONet(torch.nn.Module):
|
|||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
branch_output = self.branch_net(
|
nets_outputs = tuple(
|
||||||
x.extract(self.branch_net.input_variables))
|
net(x.extract(net.input_variables)) for net in self._nets
|
||||||
|
)
|
||||||
|
# torch.dstack(nets_outputs): (batch_size, net_output_size, n_nets)
|
||||||
|
aggregated = self._aggregator(torch.dstack(nets_outputs))
|
||||||
|
# net_output_size = output_variables * hidden_size
|
||||||
|
aggregated_reshaped = aggregated.view(
|
||||||
|
(len(x), len(self.output_variables), -1)
|
||||||
|
)
|
||||||
|
output_ = self._reduction(aggregated_reshaped)
|
||||||
|
output_ = torch.squeeze(torch.atleast_3d(output_), dim=2)
|
||||||
|
|
||||||
trunk_output = self.trunk_net(
|
assert output_.shape == (len(x), len(self.output_variables))
|
||||||
x.extract(self.trunk_net.input_variables))
|
|
||||||
|
|
||||||
output_ = self.reduction(trunk_output * branch_output)
|
|
||||||
|
|
||||||
output_ = output_.as_subclass(LabelTensor)
|
output_ = output_.as_subclass(LabelTensor)
|
||||||
output_.labels = self.output_variables
|
output_.labels = self.output_variables
|
||||||
|
|
||||||
return output_
|
return output_
|
||||||
|
|||||||
@@ -26,9 +26,11 @@ class FeedForward(torch.nn.Module):
|
|||||||
`inner_size` are not considered.
|
`inner_size` are not considered.
|
||||||
:param iterable(torch.nn.Module) extra_features: the additional input
|
:param iterable(torch.nn.Module) extra_features: the additional input
|
||||||
features to use ad augmented input.
|
features to use ad augmented input.
|
||||||
|
:param bool bias: If `True` the MLP will consider some bias.
|
||||||
"""
|
"""
|
||||||
def __init__(self, input_variables, output_variables, inner_size=20,
|
def __init__(self, input_variables, output_variables, inner_size=20,
|
||||||
n_layers=2, func=nn.Tanh, layers=None, extra_features=None):
|
n_layers=2, func=nn.Tanh, layers=None, extra_features=None,
|
||||||
|
bias=True):
|
||||||
"""
|
"""
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -62,7 +64,9 @@ class FeedForward(torch.nn.Module):
|
|||||||
|
|
||||||
self.layers = []
|
self.layers = []
|
||||||
for i in range(len(tmp_layers)-1):
|
for i in range(len(tmp_layers)-1):
|
||||||
self.layers.append(nn.Linear(tmp_layers[i], tmp_layers[i+1]))
|
self.layers.append(
|
||||||
|
nn.Linear(tmp_layers[i], tmp_layers[i + 1], bias=bias)
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(func, list):
|
if isinstance(func, list):
|
||||||
self.functions = func
|
self.functions = func
|
||||||
@@ -94,7 +98,7 @@ class FeedForward(torch.nn.Module):
|
|||||||
if self.input_variables:
|
if self.input_variables:
|
||||||
x = x.extract(self.input_variables)
|
x = x.extract(self.input_variables)
|
||||||
|
|
||||||
for i, feature in enumerate(self.extra_features):
|
for feature in self.extra_features:
|
||||||
x = x.append(feature(x))
|
x = x.append(feature(x))
|
||||||
|
|
||||||
output = self.model(x).as_subclass(LabelTensor)
|
output = self.model(x).as_subclass(LabelTensor)
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
"""Utils module"""
|
"""Utils module"""
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
|
import types
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader, default_collate, ConcatDataset
|
from torch.utils.data import DataLoader, default_collate, ConcatDataset
|
||||||
|
|
||||||
@@ -85,6 +87,17 @@ def torch_lhs(n, dim):
|
|||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
||||||
|
def is_function(f):
|
||||||
|
"""
|
||||||
|
Checks whether the given object `f` is a function or lambda.
|
||||||
|
|
||||||
|
:param object f: The object to be checked.
|
||||||
|
:return: `True` if `f` is a function, `False` otherwise.
|
||||||
|
:rtype: bool
|
||||||
|
"""
|
||||||
|
return type(f) == types.FunctionType or type(f) == types.LambdaType
|
||||||
|
|
||||||
|
|
||||||
class PinaDataset():
|
class PinaDataset():
|
||||||
|
|
||||||
def __init__(self, pinn) -> None:
|
def __init__(self, pinn) -> None:
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
import torch
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
from pina import LabelTensor
|
from pina import LabelTensor
|
||||||
from pina.model import DeepONet, FeedForward as FFN
|
from pina.model import DeepONet
|
||||||
|
from pina.model import FeedForward as FFN
|
||||||
|
|
||||||
data = torch.rand((20, 3))
|
data = torch.rand((20, 3))
|
||||||
input_vars = ['a', 'b', 'c']
|
input_vars = ['a', 'b', 'c']
|
||||||
@@ -14,19 +14,17 @@ input_ = LabelTensor(data, input_vars)
|
|||||||
def test_constructor():
|
def test_constructor():
|
||||||
branch = FFN(input_variables=['a', 'c'], output_variables=20)
|
branch = FFN(input_variables=['a', 'c'], output_variables=20)
|
||||||
trunk = FFN(input_variables=['b'], output_variables=20)
|
trunk = FFN(input_variables=['b'], output_variables=20)
|
||||||
onet = DeepONet(trunk_net=trunk, branch_net=branch,
|
onet = DeepONet(nets=[trunk, branch], output_variables=output_vars)
|
||||||
output_variables=output_vars)
|
|
||||||
|
|
||||||
def test_constructor_fails_when_invalid_inner_layer_size():
|
def test_constructor_fails_when_invalid_inner_layer_size():
|
||||||
branch = FFN(input_variables=['a', 'c'], output_variables=20)
|
branch = FFN(input_variables=['a', 'c'], output_variables=20)
|
||||||
trunk = FFN(input_variables=['b'], output_variables=19)
|
trunk = FFN(input_variables=['b'], output_variables=19)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
DeepONet(trunk_net=trunk, branch_net=branch, output_variables=output_vars)
|
DeepONet(nets=[trunk, branch], output_variables=output_vars)
|
||||||
|
|
||||||
def test_forward():
|
def test_forward():
|
||||||
branch = FFN(input_variables=['a', 'c'], output_variables=10)
|
branch = FFN(input_variables=['a', 'c'], output_variables=10)
|
||||||
trunk = FFN(input_variables=['b'], output_variables=10)
|
trunk = FFN(input_variables=['b'], output_variables=10)
|
||||||
onet = DeepONet(trunk_net=trunk, branch_net=branch,
|
onet = DeepONet(nets=[trunk, branch], output_variables=output_vars)
|
||||||
output_variables=output_vars)
|
|
||||||
output_ = onet(input_)
|
output_ = onet(input_)
|
||||||
assert output_.labels == output_vars
|
assert output_.labels == output_vars
|
||||||
|
|||||||
Reference in New Issue
Block a user