"""Module for DeepONet model"""
import logging
import operator
from functools import partial, reduce

import torch
import torch.nn as nn

from pina import LabelTensor
from pina.model import FeedForward
from pina.utils import is_function
def check_combos(combos, variables):
"""
Check that the given combinations are subsets (overlapping
is allowed) of the given set of variables.
:param iterable(iterable(str)) combos: Combinations of variables.
:param iterable(str) variables: Variables.
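
    :Example:
        >>> check_combos([("a", "b"), ("b",)], ["a", "b", "c"])  # OK, overlap allowed
        >>> check_combos([("a", "d")], ["a", "b", "c"])
        Traceback (most recent call last):
        ...
        ValueError: Combinations should be (overlapping) subsets of input variables, d is not an input variable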
"""
for combo in combos:
for variable in combo:
if variable not in variables:
raise ValueError(
f"Combinations should be (overlapping) subsets of input variables, {variable} is not an input variable"
)
def spawn_combo_networks(
combos, layers, output_dimension, func, extra_feature, bias=True
):
"""
Spawn internal networks for DeepONet based on the given combos.
:param iterable(iterable(str)) combos: Combinations of variables.
:param iterable(int) layers: Size of hidden layers.
:param int output_dimension: Size of the output layer of the networks.
:param func: Nonlinearity.
    :param extra_feature: Extra feature(s) to be considered by the networks:
        either a single feature function shared by all the networks, or a
        callable mapping each combo to its own extra feature(s).
:param bool bias: Whether to consider bias or not.
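
    :Example:
        >>> # Illustrative sketch: two overlapping combos and no extra
        >>> # features (assumes FeedForward accepts extra_features=None).
        >>> nets = spawn_combo_networks(
        ...     combos=[("a", "b"), ("b", "c")],
        ...     layers=[10, 10],
        ...     output_dimension=20,
        ...     func=torch.nn.Tanh,
        ...     extra_feature=lambda combo: None,
        ... )
        >>> len(nets)
        2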
"""
    if is_function(extra_feature):
        # a single feature function: the same extra feature is shared by
        # every spawned network, regardless of the combo
        extra_feature_func = lambda _: extra_feature
    else:
        # otherwise, treat it as a factory mapping each combo to its own
        # extra feature(s)
        extra_feature_func = extra_feature
return [
FeedForward(
layers=layers,
input_variables=tuple(combo),
output_variables=output_dimension,
func=func,
extra_features=extra_feature_func(combo),
bias=bias,
)
for combo in combos
]
class DeepONet(torch.nn.Module):
"""
    The PINA implementation of the DeepONet network.
.. seealso::
**Original reference**: Lu, L., Jin, P., Pang, G. et al. *Learning
nonlinear operators via DeepONet based on the universal approximation
        theorem of operators*. Nat Mach Intell 3, 218-229 (2021).
DOI: `10.1038/s42256-021-00302-5
<https://doi.org/10.1038/s42256-021-00302-5>`_
"""
def __init__(self, nets, output_variables, aggregator="*", reduction="+"):
"""
:param iterable(torch.nn.Module) nets: Internal DeepONet networks
(branch and trunk in the original DeepONet).
:param list(str) output_variables: the list containing the labels
corresponding to the components of the output computed by the
model.
:param str | callable aggregator: Aggregator to be used to aggregate
partial results from the modules in `nets`. Partial results are
aggregated component-wise. See
:func:`pina.model.deeponet.DeepONet._symbol_functions` for the
available default aggregators.
        :param str | callable reduction: Reduction to be used to reduce the
            aggregated result of the modules in `nets` to the desired output
            dimension. See
            :func:`pina.model.deeponet.DeepONet._symbol_functions` for the
            available default reductions.
        :Example:
            >>> branch = FeedForward(input_variables=['a', 'c'], output_variables=20)
            >>> trunk = FeedForward(input_variables=['b'], output_variables=20)
            >>> onet = DeepONet(nets=[trunk, branch], output_variables=['u'])
            >>> onet
DeepONet(
(trunk_net): FeedForward(
(extra_features): Sequential()
(model): Sequential(
(0): Linear(in_features=1, out_features=20, bias=True)
(1): Tanh()
(2): Linear(in_features=20, out_features=20, bias=True)
(3): Tanh()
(4): Linear(in_features=20, out_features=20, bias=True)
)
)
(branch_net): FeedForward(
(extra_features): Sequential()
(model): Sequential(
(0): Linear(in_features=2, out_features=20, bias=True)
(1): Tanh()
(2): Linear(in_features=20, out_features=20, bias=True)
(3): Tanh()
(4): Linear(in_features=20, out_features=20, bias=True)
)
)
)
"""
super().__init__()
self.output_variables = output_variables
self.output_dimension = len(output_variables)
        self._init_aggregator(aggregator, n_nets=len(nets))

        # all the networks must agree on the output layer size, which is also
        # the hidden size seen by the reduction
        if not DeepONet._all_nets_same_output_layer_size(nets):
            raise ValueError("All networks should have the same output size")
        hidden_size = nets[0].model[-1].out_features
        self._init_reduction(reduction, hidden_size=hidden_size)

        self._nets = torch.nn.ModuleList(nets)
        logging.info("DeepONet children: %s", list(self.children()))

        # learnable affine transformation applied to the final output
        self.scale = torch.nn.Parameter(torch.tensor([1.0]))
        self.translation = torch.nn.Parameter(torch.tensor([0.0]))
@staticmethod
def _symbol_functions(**kwargs):
"""
Return a dictionary of functions that can be used as aggregators or
reductions.
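
        :Example:
            >>> funcs = DeepONet._symbol_functions(dim=1)
            >>> funcs["+"](torch.ones(2, 3))
            tensor([3., 3.])
            >>> funcs["max"](torch.ones(2, 3))
            tensor([1., 1.])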
"""
return {
"+": partial(torch.sum, **kwargs),
"*": partial(torch.prod, **kwargs),
"mean": partial(torch.mean, **kwargs),
"min": lambda x: torch.min(x, **kwargs).values,
"max": lambda x: torch.max(x, **kwargs).values,
}
def _init_aggregator(self, aggregator, n_nets):
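        """
        Resolve and validate the aggregator, which is applied along the last
        dimension of the stacked networks' outputs.

        :param str | callable aggregator: A symbol from
            :func:`DeepONet._symbol_functions`, the string ``"linear"``, or a
            custom module/function.
        :param int n_nets: The number of internal networks to be aggregated.
        """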
aggregator_funcs = DeepONet._symbol_functions(dim=2)
if aggregator in aggregator_funcs:
aggregator_func = aggregator_funcs[aggregator]
elif isinstance(aggregator, nn.Module) or is_function(aggregator):
aggregator_func = aggregator
elif aggregator == "linear":
aggregator_func = nn.Linear(n_nets, len(self.output_variables))
else:
raise ValueError(f"Unsupported aggregation: {str(aggregator)}")
self._aggregator = aggregator_func
logging.info("Selected aggregator: %s", str(aggregator_func))
# test the aggregator
test = self._aggregator(torch.ones((20, 3, n_nets)))
if test.ndim < 2 or tuple(test.shape)[:2] != (20, 3):
raise ValueError(
f"Invalid aggregator output shape: {(20, 3, n_nets)} -> {test.shape}"
)
def _init_reduction(self, reduction, hidden_size):
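        """
        Resolve and validate the reduction, which maps the aggregated tensor
        to one value per output variable.

        :param str | callable reduction: A symbol from
            :func:`DeepONet._symbol_functions`, the string ``"linear"``, or a
            custom module/function.
        :param int hidden_size: The size of the last dimension of the
            aggregated tensor.
        """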
reduction_funcs = DeepONet._symbol_functions(dim=2)
if reduction in reduction_funcs:
reduction_func = reduction_funcs[reduction]
elif isinstance(reduction, nn.Module) or is_function(reduction):
reduction_func = reduction
elif reduction == "linear":
reduction_func = nn.Linear(hidden_size, len(self.output_variables))
else:
raise ValueError(f"Unsupported reduction: {reduction}")
self._reduction = reduction_func
logging.info("Selected reduction: %s", str(reduction))
# test the reduction
test = self._reduction(torch.ones((20, 3, hidden_size)))
if test.ndim < 2 or tuple(test.shape)[:2] != (20, 3):
msg = f"Invalid reduction output shape: {(20, 3, hidden_size)} -> {test.shape}"
raise ValueError(msg)
@staticmethod
def _all_nets_same_output_layer_size(nets):
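        """
        Check whether all the given networks share the same output layer size.

        :param iterable(torch.nn.Module) nets: Internal DeepONet networks.
        :return: ``True`` if all the output layers have the same size.
        :rtype: bool
        """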
        size = nets[0].model[-1].out_features
        return all(net.model[-1].out_features == size for net in nets[1:])
@property
def input_variables(self):
"""The input variables of the model"""
        nets_input_variables = map(lambda net: net.input_variables, self._nets)
        # operator.add concatenates the label lists; variables shared by
        # multiple networks appear multiple times, since combos may overlap
        return reduce(operator.add, nets_input_variables)
def forward(self, x):
"""
Defines the computation performed at every call.
:param LabelTensor x: the input tensor.
:return: the output computed by the model.
:rtype: LabelTensor
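
        :Example:
            >>> # Illustrative sketch: ``onet`` is the model built in the
            >>> # constructor example above (output_variables=['u']), and the
            >>> # LabelTensor(tensor, labels) constructor is assumed.
            >>> x = LabelTensor(torch.rand(10, 3), ['a', 'b', 'c'])
            >>> y = onet(x)
            >>> y.shape
            torch.Size([10, 1])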
"""
nets_outputs = tuple(
net(x.extract(net.input_variables)) for net in self._nets
)
# torch.dstack(nets_outputs): (batch_size, net_output_size, n_nets)
aggregated = self._aggregator(torch.dstack(nets_outputs))
        # net_output_size = len(output_variables) * hidden_size
aggregated_reshaped = aggregated.view(
(len(x), len(self.output_variables), -1)
)
output_ = self._reduction(aggregated_reshaped)
output_ = torch.squeeze(torch.atleast_3d(output_), dim=2)
assert output_.shape == (len(x), len(self.output_variables))
output_ = output_.as_subclass(LabelTensor)
output_.labels = self.output_variables
        output_ *= self.scale
        output_ += self.translation
return output_