add docs
This commit is contained in:
@@ -12,7 +12,7 @@ from pina.utils import is_function
|
||||
|
||||
def check_combos(combos, variables):
|
||||
"""
|
||||
Check that the given combinations are subsets (overlapping
|
||||
Check that the given combinations are subsets (overlapping
|
||||
is allowed) of the given set of variables.
|
||||
|
||||
:param iterable(iterable(str)) combos: Combinations of variables.
|
||||
@@ -35,7 +35,7 @@ def spawn_combo_networks(
|
||||
:param iterable(iterable(str)) combos: Combinations of variables.
|
||||
:param iterable(int) layers: Size of hidden layers.
|
||||
:param int output_dimension: Size of the output layer of the networks.
|
||||
:param func: Nonlinearity.
|
||||
:param func: Nonlinearity.
|
||||
:param extra_feature: Extra feature to be considered by the networks.
|
||||
:param bool bias: Whether to consider bias or not.
|
||||
"""
|
||||
@@ -78,15 +78,16 @@ class DeepONet(torch.nn.Module):
|
||||
:param list(str) output_variables: the list containing the labels
|
||||
corresponding to the components of the output computed by the
|
||||
model.
|
||||
:param string | callable aggregator: Aggregator to be used to aggregate
|
||||
:param str | callable aggregator: Aggregator to be used to aggregate
|
||||
partial results from the modules in `nets`. Partial results are
|
||||
aggregated component-wise. See :func:`_symbol_functions` for the
|
||||
aggregated component-wise. See
|
||||
:func:`pina.model.deeponet.DeepONet._symbol_functions` for the
|
||||
available default aggregators.
|
||||
:param string | callable reduction: Reduction to be used to reduce
|
||||
:param str | callable reduction: Reduction to be used to reduce
|
||||
the aggregated result of the modules in `nets` to the desired output
|
||||
dimension. See :func:`_symbol_functions` for the available default
|
||||
reductions.
|
||||
|
||||
dimension. See :py:obj:`pina.model.deeponet.DeepONet._symbol_functions` for the available default
|
||||
reductions.
|
||||
|
||||
:Example:
|
||||
>>> branch = FFN(input_variables=['a', 'c'], output_variables=20)
|
||||
>>> trunk = FFN(input_variables=['b'], output_variables=20)
|
||||
@@ -127,9 +128,15 @@ class DeepONet(torch.nn.Module):
|
||||
raise ValueError("All networks should have the same output size")
|
||||
self._nets = torch.nn.ModuleList(nets)
|
||||
logging.info("Combo DeepONet children: %s", list(self.children()))
|
||||
self.scale = torch.nn.Parameter(torch.tensor([1.0]))
|
||||
self.trasl = torch.nn.Parameter(torch.tensor([0.0]))
|
||||
|
||||
@staticmethod
|
||||
def _symbol_functions(**kwargs):
|
||||
"""
|
||||
Return a dictionary of functions that can be used as aggregators or
|
||||
reductions.
|
||||
"""
|
||||
return {
|
||||
"+": partial(torch.sum, **kwargs),
|
||||
"*": partial(torch.prod, **kwargs),
|
||||
@@ -215,4 +222,7 @@ class DeepONet(torch.nn.Module):
|
||||
|
||||
output_ = output_.as_subclass(LabelTensor)
|
||||
output_.labels = self.output_variables
|
||||
|
||||
output_ *= self.scale
|
||||
output_ += self.trasl
|
||||
return output_
|
||||
|
||||
Reference in New Issue
Block a user