add docs
This commit is contained in:
@@ -12,7 +12,7 @@ from pina.utils import is_function
|
||||
|
||||
def check_combos(combos, variables):
|
||||
"""
|
||||
Check that the given combinations are subsets (overlapping
|
||||
Check that the given combinations are subsets (overlapping
|
||||
is allowed) of the given set of variables.
|
||||
|
||||
:param iterable(iterable(str)) combos: Combinations of variables.
|
||||
@@ -35,7 +35,7 @@ def spawn_combo_networks(
|
||||
:param iterable(iterable(str)) combos: Combinations of variables.
|
||||
:param iterable(int) layers: Size of hidden layers.
|
||||
:param int output_dimension: Size of the output layer of the networks.
|
||||
:param func: Nonlinearity.
|
||||
:param func: Nonlinearity.
|
||||
:param extra_feature: Extra feature to be considered by the networks.
|
||||
:param bool bias: Whether to consider bias or not.
|
||||
"""
|
||||
@@ -78,15 +78,16 @@ class DeepONet(torch.nn.Module):
|
||||
:param list(str) output_variables: the list containing the labels
|
||||
corresponding to the components of the output computed by the
|
||||
model.
|
||||
:param string | callable aggregator: Aggregator to be used to aggregate
|
||||
:param str | callable aggregator: Aggregator to be used to aggregate
|
||||
partial results from the modules in `nets`. Partial results are
|
||||
aggregated component-wise. See :func:`_symbol_functions` for the
|
||||
aggregated component-wise. See
|
||||
:func:`pina.model.deeponet.DeepONet._symbol_functions` for the
|
||||
available default aggregators.
|
||||
:param string | callable reduction: Reduction to be used to reduce
|
||||
:param str | callable reduction: Reduction to be used to reduce
|
||||
the aggregated result of the modules in `nets` to the desired output
|
||||
dimension. See :func:`_symbol_functions` for the available default
|
||||
reductions.
|
||||
|
||||
dimension. See :py:obj:`pina.model.deeponet.DeepONet._symbol_functions` for the available default
|
||||
reductions.
|
||||
|
||||
:Example:
|
||||
>>> branch = FFN(input_variables=['a', 'c'], output_variables=20)
|
||||
>>> trunk = FFN(input_variables=['b'], output_variables=20)
|
||||
@@ -127,9 +128,15 @@ class DeepONet(torch.nn.Module):
|
||||
raise ValueError("All networks should have the same output size")
|
||||
self._nets = torch.nn.ModuleList(nets)
|
||||
logging.info("Combo DeepONet children: %s", list(self.children()))
|
||||
self.scale = torch.nn.Parameter(torch.tensor([1.0]))
|
||||
self.trasl = torch.nn.Parameter(torch.tensor([0.0]))
|
||||
|
||||
@staticmethod
|
||||
def _symbol_functions(**kwargs):
|
||||
"""
|
||||
Return a dictionary of functions that can be used as aggregators or
|
||||
reductions.
|
||||
"""
|
||||
return {
|
||||
"+": partial(torch.sum, **kwargs),
|
||||
"*": partial(torch.prod, **kwargs),
|
||||
@@ -215,4 +222,7 @@ class DeepONet(torch.nn.Module):
|
||||
|
||||
output_ = output_.as_subclass(LabelTensor)
|
||||
output_.labels = self.output_variables
|
||||
|
||||
output_ *= self.scale
|
||||
output_ += self.trasl
|
||||
return output_
|
||||
|
||||
@@ -89,8 +89,8 @@ class FeedForward(torch.nn.Module):
|
||||
"""
|
||||
Defines the computation performed at every call.
|
||||
|
||||
:param x: the input tensor.
|
||||
:type x: :class:`pina.LabelTensor`
|
||||
:param x: .
|
||||
:type x: :class:`pina.LabelTensor`
|
||||
:return: the output computed by the model.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
|
||||
@@ -9,8 +9,9 @@ class MultiFeedForward(torch.nn.Module):
|
||||
:param dict dff_dict: dictionary of FeedForward networks.
|
||||
"""
|
||||
def __init__(self, dff_dict):
|
||||
"""
|
||||
"""
|
||||
'''
|
||||
dff_dict: dict of FeedForward objects
|
||||
'''
|
||||
super().__init__()
|
||||
|
||||
if not isinstance(dff_dict, dict):
|
||||
|
||||
Reference in New Issue
Block a user