This commit is contained in:
Your Name
2023-04-18 10:49:57 +02:00
parent da33aeae3a
commit 736c78fd64
17 changed files with 292 additions and 172 deletions

View File

@@ -12,7 +12,7 @@ from pina.utils import is_function
def check_combos(combos, variables):
"""
Check that the given combinations are subsets (overlapping
Check that the given combinations are subsets (overlapping
is allowed) of the given set of variables.
:param iterable(iterable(str)) combos: Combinations of variables.
@@ -35,7 +35,7 @@ def spawn_combo_networks(
:param iterable(iterable(str)) combos: Combinations of variables.
:param iterable(int) layers: Size of hidden layers.
:param int output_dimension: Size of the output layer of the networks.
:param func: Nonlinearity.
:param func: Nonlinearity.
:param extra_feature: Extra feature to be considered by the networks.
:param bool bias: Whether to consider bias or not.
"""
@@ -78,15 +78,16 @@ class DeepONet(torch.nn.Module):
:param list(str) output_variables: the list containing the labels
corresponding to the components of the output computed by the
model.
:param string | callable aggregator: Aggregator to be used to aggregate
:param str | callable aggregator: Aggregator to be used to aggregate
partial results from the modules in `nets`. Partial results are
aggregated component-wise. See :func:`_symbol_functions` for the
aggregated component-wise. See
:func:`pina.model.deeponet.DeepONet._symbol_functions` for the
available default aggregators.
:param string | callable reduction: Reduction to be used to reduce
:param str | callable reduction: Reduction to be used to reduce
the aggregated result of the modules in `nets` to the desired output
dimension. See :func:`_symbol_functions` for the available default
reductions.
dimension. See :py:obj:`pina.model.deeponet.DeepONet._symbol_functions` for the available default
reductions.
:Example:
>>> branch = FFN(input_variables=['a', 'c'], output_variables=20)
>>> trunk = FFN(input_variables=['b'], output_variables=20)
@@ -127,9 +128,15 @@ class DeepONet(torch.nn.Module):
raise ValueError("All networks should have the same output size")
self._nets = torch.nn.ModuleList(nets)
logging.info("Combo DeepONet children: %s", list(self.children()))
self.scale = torch.nn.Parameter(torch.tensor([1.0]))
self.trasl = torch.nn.Parameter(torch.tensor([0.0]))
@staticmethod
def _symbol_functions(**kwargs):
"""
Return a dictionary of functions that can be used as aggregators or
reductions.
"""
return {
"+": partial(torch.sum, **kwargs),
"*": partial(torch.prod, **kwargs),
@@ -215,4 +222,7 @@ class DeepONet(torch.nn.Module):
output_ = output_.as_subclass(LabelTensor)
output_.labels = self.output_variables
output_ *= self.scale
output_ += self.trasl
return output_