DeepOnet implementation, LabelTensor modification

* Implementing standard DeepOnet (trunk/branch net)
* Implementing multiple reduction/ average techniques
* Small change  LabelTensor __getitem__ for handling list
This commit is contained in:
Dario Coscia
2023-09-06 12:40:21 +02:00
committed by Nicola Demo
parent 15ecaacb7c
commit b029f18c49
3 changed files with 140 additions and 154 deletions

View File

@@ -212,7 +212,10 @@ class LabelTensor(torch.Tensor):
if selected_lt.ndim == 1: if selected_lt.ndim == 1:
selected_lt = selected_lt.reshape(-1, 1) selected_lt = selected_lt.reshape(-1, 1)
if hasattr(self, 'labels'): if hasattr(self, 'labels'):
selected_lt.labels = self.labels[index[1]] if isinstance(index[1], list):
selected_lt.labels = [self.labels[i] for i in index[1]]
else:
selected_lt.labels = self.labels[index[1]]
return selected_lt return selected_lt

View File

@@ -1,60 +1,9 @@
"""Module for DeepONet model""" """Module for DeepONet model"""
import logging
from functools import partial, reduce
import torch import torch
import torch.nn as nn import torch.nn as nn
from ..utils import check_consistency, is_function
from functools import partial, reduce
from pina import LabelTensor from pina import LabelTensor
from pina.model import FeedForward
from pina.utils import is_function
def check_combos(combos, variables):
"""
Check that the given combinations are subsets (overlapping
is allowed) of the given set of variables.
:param iterable(iterable(str)) combos: Combinations of variables.
:param iterable(str) variables: Variables.
"""
for combo in combos:
for variable in combo:
if variable not in variables:
raise ValueError(
f"Combinations should be (overlapping) subsets of input variables, {variable} is not an input variable"
)
def spawn_combo_networks(
combos, layers, output_dimension, func, extra_feature, bias=True
):
"""
Spawn internal networks for DeepONet based on the given combos.
:param iterable(iterable(str)) combos: Combinations of variables.
:param iterable(int) layers: Size of hidden layers.
:param int output_dimension: Size of the output layer of the networks.
:param func: Nonlinearity.
:param extra_feature: Extra feature to be considered by the networks.
:param bool bias: Whether to consider bias or not.
"""
if is_function(extra_feature):
extra_feature_func = lambda _: extra_feature
else:
extra_feature_func = extra_feature
return [
FeedForward(
layers=layers,
input_variables=tuple(combo),
output_variables=output_dimension,
func=func,
extra_features=extra_feature_func(combo),
bias=bias,
)
for combo in combos
]
class DeepONet(torch.nn.Module): class DeepONet(torch.nn.Module):
@@ -70,14 +19,32 @@ class DeepONet(torch.nn.Module):
<https://doi.org/10.1038/s42256-021-00302-5>`_ <https://doi.org/10.1038/s42256-021-00302-5>`_
""" """
def __init__(self,
def __init__(self, nets, output_variables, aggregator="*", reduction="+"): branch_net,
trunk_net,
input_indeces_branch_net,
input_indeces_trunk_net,
aggregator="*",
reduction="+"):
""" """
:param iterable(torch.nn.Module) nets: Internal DeepONet networks :param torch.nn.Module branch_net: The neural network to use as branch
(branch and trunk in the original DeepONet). model. It has to take as input a :class:`LabelTensor`
:param list(str) output_variables: the list containing the labels or :class:`torch.Tensor`. The number of dimensions of the output has
corresponding to the components of the output computed by the to be the same of the ``trunk_net``.
model. :param torch.nn.Module trunk_net: The neural network to use as trunk
model. It has to take as input a :class:`LabelTensor`
or :class:`torch.Tensor`. The number of dimensions of the output
has to be the same of the ``branch_net``.
:param list(int) | list(str) input_indeces_branch_net: List of indeces
to extract from the input variable in the forward pass for the
branch net. If a list of ``int`` is passed, the corresponding columns
of the inner most entries are extracted. If a list of ``str`` is passed
the variables of the corresponding :class:`LabelTensor` are extracted.
:param list(int) | list(str) input_indeces_trunk_net: List of indeces
to extract from the input variable in the forward pass for the
trunk net. If a list of ``int`` is passed, the corresponding columns
of the inner most entries are extracted. If a list of ``str`` is passed
the variables of the corresponding :class:`LabelTensor` are extracted.
:param str | callable aggregator: Aggregator to be used to aggregate :param str | callable aggregator: Aggregator to be used to aggregate
partial results from the modules in `nets`. Partial results are partial results from the modules in `nets`. Partial results are
aggregated component-wise. See aggregated component-wise. See
@@ -89,47 +56,64 @@ class DeepONet(torch.nn.Module):
reductions. reductions.
:Example: :Example:
>>> branch = FFN(input_variables=['a', 'c'], output_variables=20) >>> branch_net = FeedForward(input_dimensons=1, output_dimensions=10)
>>> trunk = FFN(input_variables=['b'], output_variables=20) >>> trunk_net = FeedForward(input_dimensons=1, output_dimensions=10)
>>> onet = DeepONet(nets=[trunk, branch], output_variables=output_vars) >>> model = DeepONet(branch_net=branch_net,
... trunk_net=trunk_net,
... input_indeces_branch_net=['x'],
... input_indeces_trunk_net=['t'],
... reduction='+',
... aggregator='*')
>>> model
DeepONet( DeepONet(
(trunk_net): FeedForward( (trunk_net): FeedForward(
(extra_features): Sequential()
(model): Sequential( (model): Sequential(
(0): Linear(in_features=1, out_features=20, bias=True) (0): Linear(in_features=1, out_features=20, bias=True)
(1): Tanh() (1): Tanh()
(2): Linear(in_features=20, out_features=20, bias=True) (2): Linear(in_features=20, out_features=20, bias=True)
(3): Tanh() (3): Tanh()
(4): Linear(in_features=20, out_features=20, bias=True) (4): Linear(in_features=20, out_features=10, bias=True)
) )
) )
(branch_net): FeedForward( (branch_net): FeedForward(
(extra_features): Sequential()
(model): Sequential( (model): Sequential(
(0): Linear(in_features=2, out_features=20, bias=True) (0): Linear(in_features=1, out_features=20, bias=True)
(1): Tanh() (1): Tanh()
(2): Linear(in_features=20, out_features=20, bias=True) (2): Linear(in_features=20, out_features=20, bias=True)
(3): Tanh() (3): Tanh()
(4): Linear(in_features=20, out_features=20, bias=True) (4): Linear(in_features=20, out_features=10, bias=True)
) )
) )
) )
""" """
super().__init__() super().__init__()
self.output_variables = output_variables # check type consistency
self.output_dimension = len(output_variables) check_consistency(input_indeces_branch_net, (str, int))
check_consistency(input_indeces_trunk_net, (str, int))
check_consistency(branch_net, torch.nn.Module)
check_consistency(trunk_net, torch.nn.Module)
self._init_aggregator(aggregator, n_nets=len(nets)) # check trunk branch nets consistency
hidden_size = nets[0].model[-1].out_features trunk_out_dim = trunk_net.layers[-1].out_features
self._init_reduction(reduction, hidden_size=hidden_size) branch_out_dim = branch_net.layers[-1].out_features
if trunk_out_dim != branch_out_dim:
raise ValueError('Branch and trunk networks have not the same '
'output dimension.')
if not DeepONet._all_nets_same_output_layer_size(nets): # assign trunk and branch net with their input indeces
raise ValueError("All networks should have the same output size") self.trunk_net = trunk_net
self._nets = torch.nn.ModuleList(nets) self._trunk_indeces = input_indeces_trunk_net
logging.info("Combo DeepONet children: %s", list(self.children())) self.branch_net = branch_net
self.scale = torch.nn.Parameter(torch.tensor([1.0])) self._branch_indeces = input_indeces_branch_net
self.trasl = torch.nn.Parameter(torch.tensor([0.0]))
# initializie aggregation
self._init_aggregator(aggregator=aggregator)
self._init_reduction(reduction=reduction)
# scale and translation
self._scale = torch.nn.Parameter(torch.tensor([1.0]))
self._trasl = torch.nn.Parameter(torch.tensor([0.0]))
@staticmethod @staticmethod
def _symbol_functions(**kwargs): def _symbol_functions(**kwargs):
@@ -145,57 +129,37 @@ class DeepONet(torch.nn.Module):
"max": lambda x: torch.max(x, **kwargs).values, "max": lambda x: torch.max(x, **kwargs).values,
} }
def _init_aggregator(self, aggregator, n_nets): def _init_aggregator(self, aggregator):
aggregator_funcs = DeepONet._symbol_functions(dim=2) aggregator_funcs = DeepONet._symbol_functions(dim=2)
if aggregator in aggregator_funcs: if aggregator in aggregator_funcs:
aggregator_func = aggregator_funcs[aggregator] aggregator_func = aggregator_funcs[aggregator]
elif isinstance(aggregator, nn.Module) or is_function(aggregator): elif isinstance(aggregator, nn.Module) or is_function(aggregator):
aggregator_func = aggregator aggregator_func = aggregator
elif aggregator == "linear":
aggregator_func = nn.Linear(n_nets, len(self.output_variables))
else: else:
raise ValueError(f"Unsupported aggregation: {str(aggregator)}") raise ValueError(f"Unsupported aggregation: {str(aggregator)}")
self._aggregator = aggregator_func self._aggregator = aggregator_func
logging.info("Selected aggregator: %s", str(aggregator_func))
# test the aggregator
test = self._aggregator(torch.ones((20, 3, n_nets)))
if test.ndim < 2 or tuple(test.shape)[:2] != (20, 3):
raise ValueError(
f"Invalid aggregator output shape: {(20, 3, n_nets)} -> {test.shape}"
)
def _init_reduction(self, reduction, hidden_size): def _init_reduction(self, reduction):
reduction_funcs = DeepONet._symbol_functions(dim=2) reduction_funcs = DeepONet._symbol_functions(dim=-1)
if reduction in reduction_funcs: if reduction in reduction_funcs:
reduction_func = reduction_funcs[reduction] reduction_func = reduction_funcs[reduction]
elif isinstance(reduction, nn.Module) or is_function(reduction): elif isinstance(reduction, nn.Module) or is_function(reduction):
reduction_func = reduction reduction_func = reduction
elif reduction == "linear":
reduction_func = nn.Linear(hidden_size, len(self.output_variables))
else: else:
raise ValueError(f"Unsupported reduction: {reduction}") raise ValueError(f"Unsupported reduction: {reduction}")
self._reduction = reduction_func self._reduction = reduction_func
logging.info("Selected reduction: %s", str(reduction))
# test the reduction def _get_vars(self, x, indeces):
test = self._reduction(torch.ones((20, 3, hidden_size))) if isinstance(indeces[0], str):
if test.ndim < 2 or tuple(test.shape)[:2] != (20, 3): check_consistency(x, LabelTensor)
msg = f"Invalid reduction output shape: {(20, 3, hidden_size)} -> {test.shape}" return x.extract(indeces)
raise ValueError(msg) elif isinstance(indeces[0], int):
return x[..., indeces]
@staticmethod else:
def _all_nets_same_output_layer_size(nets): raise RuntimeError('Not able to extract right indeces for tensor.')
size = nets[0].layers[-1].out_features
return all((net.layers[-1].out_features == size for net in nets[1:]))
@property
def input_variables(self):
"""The input variables of the model"""
nets_input_variables = map(lambda net: net.input_variables, self._nets)
return reduce(sum, nets_input_variables)
def forward(self, x): def forward(self, x):
""" """
@@ -206,23 +170,18 @@ class DeepONet(torch.nn.Module):
:rtype: LabelTensor :rtype: LabelTensor
""" """
nets_outputs = tuple( # forward pass
net(x.extract(net.input_variables)) for net in self._nets branch_output = self.branch_net(self._get_vars(x, self._branch_indeces))
) trunk_output = self.trunk_net(self._get_vars(x, self._trunk_indeces))
# torch.dstack(nets_outputs): (batch_size, net_output_size, n_nets)
aggregated = self._aggregator(torch.dstack(nets_outputs))
# net_output_size = output_variables * hidden_size
aggregated_reshaped = aggregated.view(
(len(x), len(self.output_variables), -1)
)
output_ = self._reduction(aggregated_reshaped)
output_ = torch.squeeze(torch.atleast_3d(output_), dim=2)
assert output_.shape == (len(x), len(self.output_variables)) # aggregation
aggregated = self._aggregator(torch.dstack((branch_output, trunk_output)))
output_ = output_.as_subclass(LabelTensor) # reduce
output_.labels = self.output_variables output_ = self._reduction(aggregated).reshape(-1, 1)
# scale and translate
output_ *= self._scale
output_ += self._trasl
output_ *= self.scale
output_ += self.trasl
return output_ return output_

View File

@@ -3,29 +3,53 @@ import torch
from pina import LabelTensor from pina import LabelTensor
from pina.model import DeepONet from pina.model import DeepONet
from pina.model import FeedForward as FFN from pina.model import FeedForward
data = torch.rand((20, 3)) data = torch.rand((20, 3))
input_vars = ['a', 'b', 'c'] input_vars = ['a', 'b', 'c']
output_vars = ['d']
input_ = LabelTensor(data, input_vars) input_ = LabelTensor(data, input_vars)
# TODO
# def test_constructor(): def test_constructor():
# branch = FFN(input_variables=['a', 'c'], output_variables=20) branch_net = FeedForward(input_dimensons=1, output_dimensions=10)
# trunk = FFN(input_variables=['b'], output_variables=20) trunk_net = FeedForward(input_dimensons=2, output_dimensions=10)
# onet = DeepONet(nets=[trunk, branch], output_variables=output_vars) DeepONet(branch_net=branch_net,
trunk_net=trunk_net,
input_indeces_branch_net=['a'],
input_indeces_trunk_net=['b', 'c'],
reduction='+',
aggregator='*')
# def test_constructor_fails_when_invalid_inner_layer_size():
# branch = FFN(input_variables=['a', 'c'], output_variables=20)
# trunk = FFN(input_variables=['b'], output_variables=19)
# with pytest.raises(ValueError):
# DeepONet(nets=[trunk, branch], output_variables=output_vars)
# def test_forward(): def test_constructor_fails_when_invalid_inner_layer_size():
# branch = FFN(input_variables=['a', 'c'], output_variables=10) branch_net = FeedForward(input_dimensons=1, output_dimensions=10)
# trunk = FFN(input_variables=['b'], output_variables=10) trunk_net = FeedForward(input_dimensons=2, output_dimensions=8)
# onet = DeepONet(nets=[trunk, branch], output_variables=output_vars) with pytest.raises(ValueError):
# output_ = onet(input_) DeepONet(branch_net=branch_net,
# assert output_.labels == output_vars trunk_net=trunk_net,
input_indeces_branch_net=['a'],
input_indeces_trunk_net=['b', 'c'],
reduction='+',
aggregator='*')
def test_forward_extract_str():
branch_net = FeedForward(input_dimensons=1, output_dimensions=10)
trunk_net = FeedForward(input_dimensons=2, output_dimensions=10)
model = DeepONet(branch_net=branch_net,
trunk_net=trunk_net,
input_indeces_branch_net=['a'],
input_indeces_trunk_net=['b', 'c'],
reduction='+',
aggregator='*')
model(input_)
def test_forward_extract_int():
branch_net = FeedForward(input_dimensons=1, output_dimensions=10)
trunk_net = FeedForward(input_dimensons=2, output_dimensions=10)
model = DeepONet(branch_net=branch_net,
trunk_net=trunk_net,
input_indeces_branch_net=[0],
input_indeces_trunk_net=[1, 2],
reduction='+',
aggregator='*')
model(data)