DeepONet implementation, LabelTensor modification

* Implement the standard DeepONet (trunk/branch nets)
* Implement multiple reduction/averaging techniques
* Small change to LabelTensor.__getitem__ to handle list indices
committed by Nicola Demo
parent 15ecaacb7c
commit b029f18c49
@@ -212,7 +212,10 @@ class LabelTensor(torch.Tensor):
         if selected_lt.ndim == 1:
             selected_lt = selected_lt.reshape(-1, 1)
         if hasattr(self, 'labels'):
-            selected_lt.labels = self.labels[index[1]]
+            if isinstance(index[1], list):
+                selected_lt.labels = [self.labels[i] for i in index[1]]
+            else:
+                selected_lt.labels = self.labels[index[1]]
         return selected_lt
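The new `isinstance(index[1], list)` branch is what makes column selection by a list of positions label-aware. A minimal sketch of the intended behavior, assuming the `LabelTensor(tensor, labels)` constructor used in the test file below:

import torch
from pina import LabelTensor

lt = LabelTensor(torch.rand(5, 3), ['a', 'b', 'c'])

# Here index[1] is the list [0, 2]: each position is looked up in
# self.labels, so the selected columns keep their own labels.
sub = lt[:, [0, 2]]
print(sub.labels)   # expected: ['a', 'c']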
@@ -1,60 +1,9 @@
 """Module for DeepONet model"""
-import logging
-from functools import partial, reduce
-
 import torch
 import torch.nn as nn
+from ..utils import check_consistency, is_function
+from functools import partial, reduce
 from pina import LabelTensor
-from pina.model import FeedForward
-from pina.utils import is_function
-
-
-def check_combos(combos, variables):
-    """
-    Check that the given combinations are subsets (overlapping
-    is allowed) of the given set of variables.
-
-    :param iterable(iterable(str)) combos: Combinations of variables.
-    :param iterable(str) variables: Variables.
-    """
-    for combo in combos:
-        for variable in combo:
-            if variable not in variables:
-                raise ValueError(
-                    f"Combinations should be (overlapping) subsets of input variables, {variable} is not an input variable"
-                )
-
-
-def spawn_combo_networks(
-    combos, layers, output_dimension, func, extra_feature, bias=True
-):
-    """
-    Spawn internal networks for DeepONet based on the given combos.
-
-    :param iterable(iterable(str)) combos: Combinations of variables.
-    :param iterable(int) layers: Size of hidden layers.
-    :param int output_dimension: Size of the output layer of the networks.
-    :param func: Nonlinearity.
-    :param extra_feature: Extra feature to be considered by the networks.
-    :param bool bias: Whether to consider bias or not.
-    """
-    if is_function(extra_feature):
-        extra_feature_func = lambda _: extra_feature
-    else:
-        extra_feature_func = extra_feature
-
-    return [
-        FeedForward(
-            layers=layers,
-            input_variables=tuple(combo),
-            output_variables=output_dimension,
-            func=func,
-            extra_features=extra_feature_func(combo),
-            bias=bias,
-        )
-        for combo in combos
-    ]
-
-
 class DeepONet(torch.nn.Module):
@@ -70,14 +19,32 @@ class DeepONet(torch.nn.Module):
     <https://doi.org/10.1038/s42256-021-00302-5>`_

     """

-    def __init__(self, nets, output_variables, aggregator="*", reduction="+"):
+    def __init__(self,
+                 branch_net,
+                 trunk_net,
+                 input_indeces_branch_net,
+                 input_indeces_trunk_net,
+                 aggregator="*",
+                 reduction="+"):
         """
-        :param iterable(torch.nn.Module) nets: Internal DeepONet networks
-            (branch and trunk in the original DeepONet).
-        :param list(str) output_variables: the list containing the labels
-            corresponding to the components of the output computed by the
-            model.
+        :param torch.nn.Module branch_net: The neural network to use as branch
+            model. It has to take as input a :class:`LabelTensor`
+            or :class:`torch.Tensor`. The number of dimensions of the output has
+            to be the same of the ``trunk_net``.
+        :param torch.nn.Module trunk_net: The neural network to use as trunk
+            model. It has to take as input a :class:`LabelTensor`
+            or :class:`torch.Tensor`. The number of dimensions of the output
+            has to be the same of the ``branch_net``.
+        :param list(int) | list(str) input_indeces_branch_net: List of indeces
+            to extract from the input variable in the forward pass for the
+            branch net. If a list of ``int`` is passed, the corresponding columns
+            of the inner most entries are extracted. If a list of ``str`` is passed
+            the variables of the corresponding :class:`LabelTensor` are extracted.
+        :param list(int) | list(str) input_indeces_trunk_net: List of indeces
+            to extract from the input variable in the forward pass for the
+            trunk net. If a list of ``int`` is passed, the corresponding columns
+            of the inner most entries are extracted. If a list of ``str`` is passed
+            the variables of the corresponding :class:`LabelTensor` are extracted.
         :param str | callable aggregator: Aggregator to be used to aggregate
             partial results from the modules in `nets`. Partial results are
             aggregated component-wise. See
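The docstring above only requires two `torch.nn.Module`s with matching output widths, so any custom pair works in principle. One caveat visible in the next hunk: the constructor reads `net.layers[-1].out_features`, so a custom module must expose a `layers` attribute like PINA's `FeedForward` does. A hypothetical sketch (`TinyNet` is not part of the commit):

import torch
import torch.nn as nn
from pina.model import DeepONet

class TinyNet(nn.Module):
    """Illustrative stand-in for FeedForward, exposing the `layers`
    attribute that DeepONet's output-width check relies on."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Linear(in_dim, 20), nn.Linear(20, out_dim)])

    def forward(self, x):
        return self.layers[1](torch.tanh(self.layers[0](x)))

model = DeepONet(branch_net=TinyNet(1, 10),
                 trunk_net=TinyNet(2, 10),
                 input_indeces_branch_net=[0],
                 input_indeces_trunk_net=[1, 2],
                 aggregator='*',
                 reduction='+')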
@@ -89,47 +56,64 @@ class DeepONet(torch.nn.Module):
         reductions.

         :Example:
-        >>> branch = FFN(input_variables=['a', 'c'], output_variables=20)
-        >>> trunk = FFN(input_variables=['b'], output_variables=20)
-        >>> onet = DeepONet(nets=[trunk, branch], output_variables=output_vars)
+        >>> branch_net = FeedForward(input_dimensons=1, output_dimensions=10)
+        >>> trunk_net = FeedForward(input_dimensons=1, output_dimensions=10)
+        >>> model = DeepONet(branch_net=branch_net,
+        ...                  trunk_net=trunk_net,
+        ...                  input_indeces_branch_net=['x'],
+        ...                  input_indeces_trunk_net=['t'],
+        ...                  reduction='+',
+        ...                  aggregator='*')
+        >>> model
         DeepONet(
           (trunk_net): FeedForward(
-            (extra_features): Sequential()
             (model): Sequential(
               (0): Linear(in_features=1, out_features=20, bias=True)
               (1): Tanh()
               (2): Linear(in_features=20, out_features=20, bias=True)
               (3): Tanh()
-              (4): Linear(in_features=20, out_features=20, bias=True)
+              (4): Linear(in_features=20, out_features=10, bias=True)
             )
           )
           (branch_net): FeedForward(
-            (extra_features): Sequential()
             (model): Sequential(
-              (0): Linear(in_features=2, out_features=20, bias=True)
+              (0): Linear(in_features=1, out_features=20, bias=True)
               (1): Tanh()
               (2): Linear(in_features=20, out_features=20, bias=True)
               (3): Tanh()
-              (4): Linear(in_features=20, out_features=20, bias=True)
+              (4): Linear(in_features=20, out_features=10, bias=True)
             )
           )
         )
         """
         super().__init__()

-        self.output_variables = output_variables
-        self.output_dimension = len(output_variables)
-
-        self._init_aggregator(aggregator, n_nets=len(nets))
-        hidden_size = nets[0].model[-1].out_features
-        self._init_reduction(reduction, hidden_size=hidden_size)
-
-        if not DeepONet._all_nets_same_output_layer_size(nets):
-            raise ValueError("All networks should have the same output size")
-        self._nets = torch.nn.ModuleList(nets)
-        logging.info("Combo DeepONet children: %s", list(self.children()))
-        self.scale = torch.nn.Parameter(torch.tensor([1.0]))
-        self.trasl = torch.nn.Parameter(torch.tensor([0.0]))
+        # check type consistency
+        check_consistency(input_indeces_branch_net, (str, int))
+        check_consistency(input_indeces_trunk_net, (str, int))
+        check_consistency(branch_net, torch.nn.Module)
+        check_consistency(trunk_net, torch.nn.Module)
+
+        # check trunk branch nets consistency
+        trunk_out_dim = trunk_net.layers[-1].out_features
+        branch_out_dim = branch_net.layers[-1].out_features
+        if trunk_out_dim != branch_out_dim:
+            raise ValueError('Branch and trunk networks have not the same '
+                             'output dimension.')
+
+        # assign trunk and branch net with their input indeces
+        self.trunk_net = trunk_net
+        self._trunk_indeces = input_indeces_trunk_net
+        self.branch_net = branch_net
+        self._branch_indeces = input_indeces_branch_net
+
+        # initializie aggregation
+        self._init_aggregator(aggregator=aggregator)
+        self._init_reduction(reduction=reduction)
+
+        # scale and translation
+        self._scale = torch.nn.Parameter(torch.tensor([1.0]))
+        self._trasl = torch.nn.Parameter(torch.tensor([0.0]))

     @staticmethod
     def _symbol_functions(**kwargs):
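The new `_scale` and `_trasl` parameters add a learnable affine correction after reduction; at initialization they are the identity mapping. The transform in isolation, as a minimal sketch:

import torch

# Same initialization as the constructor above: identity at the start.
scale = torch.nn.Parameter(torch.tensor([1.0]))
trasl = torch.nn.Parameter(torch.tensor([0.0]))

out = torch.rand(20, 1)
assert torch.equal(out * scale + trasl, out)   # a no-op until trained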
@@ -145,57 +129,37 @@ class DeepONet(torch.nn.Module):
             "max": lambda x: torch.max(x, **kwargs).values,
         }

-    def _init_aggregator(self, aggregator, n_nets):
+    def _init_aggregator(self, aggregator):
         aggregator_funcs = DeepONet._symbol_functions(dim=2)
         if aggregator in aggregator_funcs:
             aggregator_func = aggregator_funcs[aggregator]
         elif isinstance(aggregator, nn.Module) or is_function(aggregator):
             aggregator_func = aggregator
-        elif aggregator == "linear":
-            aggregator_func = nn.Linear(n_nets, len(self.output_variables))
         else:
             raise ValueError(f"Unsupported aggregation: {str(aggregator)}")

         self._aggregator = aggregator_func
-        logging.info("Selected aggregator: %s", str(aggregator_func))
-
-        # test the aggregator
-        test = self._aggregator(torch.ones((20, 3, n_nets)))
-        if test.ndim < 2 or tuple(test.shape)[:2] != (20, 3):
-            raise ValueError(
-                f"Invalid aggregator output shape: {(20, 3, n_nets)} -> {test.shape}"
-            )

-    def _init_reduction(self, reduction, hidden_size):
-        reduction_funcs = DeepONet._symbol_functions(dim=2)
+    def _init_reduction(self, reduction):
+        reduction_funcs = DeepONet._symbol_functions(dim=-1)
         if reduction in reduction_funcs:
             reduction_func = reduction_funcs[reduction]
         elif isinstance(reduction, nn.Module) or is_function(reduction):
             reduction_func = reduction
-        elif reduction == "linear":
-            reduction_func = nn.Linear(hidden_size, len(self.output_variables))
         else:
             raise ValueError(f"Unsupported reduction: {reduction}")

         self._reduction = reduction_func
-        logging.info("Selected reduction: %s", str(reduction))
-
-        # test the reduction
-        test = self._reduction(torch.ones((20, 3, hidden_size)))
-        if test.ndim < 2 or tuple(test.shape)[:2] != (20, 3):
-            msg = f"Invalid reduction output shape: {(20, 3, hidden_size)} -> {test.shape}"
-            raise ValueError(msg)
-
-    @staticmethod
-    def _all_nets_same_output_layer_size(nets):
-        size = nets[0].layers[-1].out_features
-        return all((net.layers[-1].out_features == size for net in nets[1:]))
-
-    @property
-    def input_variables(self):
-        """The input variables of the model"""
-        nets_input_variables = map(lambda net: net.input_variables, self._nets)
-        return reduce(sum, nets_input_variables)
+    def _get_vars(self, x, indeces):
+        if isinstance(indeces[0], str):
+            check_consistency(x, LabelTensor)
+            return x.extract(indeces)
+        elif isinstance(indeces[0], int):
+            return x[..., indeces]
+        else:
+            raise RuntimeError('Not able to extract right indeces for tensor.')

     def forward(self, x):
         """
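Both `_init_aggregator` and `_init_reduction` fall back to a user-supplied callable or `nn.Module` when the symbol lookup fails, so custom combinations beyond the built-in table are possible. A hedged sketch of the expected call signatures, assuming `'*'` and `'+'` map to `torch.prod` and `torch.sum` as the visible `"max"` entry of `_symbol_functions` suggests:

import torch

def mean_aggregator(x):
    # Collapses the stacked-nets axis, like the built-in '*' (dim=2).
    return torch.mean(x, dim=2)

def sum_reduction(x):
    # Collapses the hidden axis, like the built-in '+' (dim=-1).
    return torch.sum(x, dim=-1)

x = torch.ones(4, 10, 2)                 # (batch, hidden, n_nets)
assert mean_aggregator(x).shape == (4, 10)
assert sum_reduction(mean_aggregator(x)).shape == (4,)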
@@ -206,23 +170,18 @@ class DeepONet(torch.nn.Module):
         :rtype: LabelTensor
         """

-        nets_outputs = tuple(
-            net(x.extract(net.input_variables)) for net in self._nets
-        )
-        # torch.dstack(nets_outputs): (batch_size, net_output_size, n_nets)
-        aggregated = self._aggregator(torch.dstack(nets_outputs))
-        # net_output_size = output_variables * hidden_size
-        aggregated_reshaped = aggregated.view(
-            (len(x), len(self.output_variables), -1)
-        )
-        output_ = self._reduction(aggregated_reshaped)
-        output_ = torch.squeeze(torch.atleast_3d(output_), dim=2)
-
-        assert output_.shape == (len(x), len(self.output_variables))
-
-        output_ = output_.as_subclass(LabelTensor)
-        output_.labels = self.output_variables
-
-        output_ *= self.scale
-        output_ += self.trasl
+        # forward pass
+        branch_output = self.branch_net(self._get_vars(x, self._branch_indeces))
+        trunk_output = self.trunk_net(self._get_vars(x, self._trunk_indeces))
+
+        # aggregation
+        aggregated = self._aggregator(torch.dstack((branch_output, trunk_output)))
+
+        # reduce
+        output_ = self._reduction(aggregated).reshape(-1, 1)
+
+        # scale and translate
+        output_ *= self._scale
+        output_ += self._trasl

         return output_
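The rewritten forward pass is easiest to follow as a shape trace. A plain-torch sketch under the defaults `aggregator='*'`, `reduction='+'` (again assuming the prod/sum mappings of the symbol table), which reproduces the classic DeepONet inner product between branch and trunk features:

import torch

batch, hidden = 20, 10
branch_output = torch.rand(batch, hidden)
trunk_output = torch.rand(batch, hidden)

stacked = torch.dstack((branch_output, trunk_output))    # (20, 10, 2)
aggregated = torch.prod(stacked, dim=2)                  # '*': (20, 10)
output_ = torch.sum(aggregated, dim=-1).reshape(-1, 1)   # '+': (20, 1)

# Equivalent to the standard DeepONet dot product of the two feature maps.
expected = (branch_output * trunk_output).sum(dim=-1, keepdim=True)
assert torch.allclose(output_, expected)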
@@ -3,29 +3,53 @@ import torch

 from pina import LabelTensor
 from pina.model import DeepONet
-from pina.model import FeedForward as FFN
+from pina.model import FeedForward

 data = torch.rand((20, 3))
 input_vars = ['a', 'b', 'c']
-output_vars = ['d']
 input_ = LabelTensor(data, input_vars)

-# TODO
-
-# def test_constructor():
-#     branch = FFN(input_variables=['a', 'c'], output_variables=20)
-#     trunk = FFN(input_variables=['b'], output_variables=20)
-#     onet = DeepONet(nets=[trunk, branch], output_variables=output_vars)
-
-# def test_constructor_fails_when_invalid_inner_layer_size():
-#     branch = FFN(input_variables=['a', 'c'], output_variables=20)
-#     trunk = FFN(input_variables=['b'], output_variables=19)
-#     with pytest.raises(ValueError):
-#         DeepONet(nets=[trunk, branch], output_variables=output_vars)
-
-# def test_forward():
-#     branch = FFN(input_variables=['a', 'c'], output_variables=10)
-#     trunk = FFN(input_variables=['b'], output_variables=10)
-#     onet = DeepONet(nets=[trunk, branch], output_variables=output_vars)
-#     output_ = onet(input_)
-#     assert output_.labels == output_vars
+def test_constructor():
+    branch_net = FeedForward(input_dimensons=1, output_dimensions=10)
+    trunk_net = FeedForward(input_dimensons=2, output_dimensions=10)
+    DeepONet(branch_net=branch_net,
+             trunk_net=trunk_net,
+             input_indeces_branch_net=['a'],
+             input_indeces_trunk_net=['b', 'c'],
+             reduction='+',
+             aggregator='*')
+
+def test_constructor_fails_when_invalid_inner_layer_size():
+    branch_net = FeedForward(input_dimensons=1, output_dimensions=10)
+    trunk_net = FeedForward(input_dimensons=2, output_dimensions=8)
+    with pytest.raises(ValueError):
+        DeepONet(branch_net=branch_net,
+                 trunk_net=trunk_net,
+                 input_indeces_branch_net=['a'],
+                 input_indeces_trunk_net=['b', 'c'],
+                 reduction='+',
+                 aggregator='*')
+
+def test_forward_extract_str():
+    branch_net = FeedForward(input_dimensons=1, output_dimensions=10)
+    trunk_net = FeedForward(input_dimensons=2, output_dimensions=10)
+    model = DeepONet(branch_net=branch_net,
+                     trunk_net=trunk_net,
+                     input_indeces_branch_net=['a'],
+                     input_indeces_trunk_net=['b', 'c'],
+                     reduction='+',
+                     aggregator='*')
+    model(input_)
+
+def test_forward_extract_int():
+    branch_net = FeedForward(input_dimensons=1, output_dimensions=10)
+    trunk_net = FeedForward(input_dimensons=2, output_dimensions=10)
+    model = DeepONet(branch_net=branch_net,
+                     trunk_net=trunk_net,
+                     input_indeces_branch_net=[0],
+                     input_indeces_trunk_net=[1, 2],
+                     reduction='+',
+                     aggregator='*')
+    model(data)
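The two forward tests tie the halves of the commit together: the int path of `_get_vars` indexes with a list (`x[..., indeces]`), which is exactly the case the `LabelTensor.__getitem__` change above makes label-aware. A sketch of what each net receives, assuming the module-level fixtures of this test file:

import torch
from pina import LabelTensor

data = torch.rand(20, 3)
input_ = LabelTensor(data, ['a', 'b', 'c'])

# str path (test_forward_extract_str): extraction by variable name.
branch_in = input_.extract(['a'])    # expected shape (20, 1)

# int path (test_forward_extract_int): list indexing, routed through the
# new isinstance(index[1], list) branch of LabelTensor.__getitem__.
trunk_in = input_[..., [1, 2]]       # expected shape (20, 2), labels ['b', 'c']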