Refactoring solvers (#541)
* Refactoring solvers * Simplify logic compile * Improve and update doc * Create SupervisedSolverInterface * Specialize SupervisedSolver and ReducedOrderModelSolver * Create EnsembleSolverInterface + EnsembleSupervisedSolver * Create tests ensemble solvers * formatter * codacy * fix issues + speedup test
This commit is contained in:
@@ -72,18 +72,22 @@ def labelize_forward(forward, input_variables, output_variables):
|
||||
:rtype: Callable
|
||||
"""
|
||||
|
||||
def wrapper(x):
|
||||
def wrapper(x, *args, **kwargs):
|
||||
"""
|
||||
Decorated forward function.
|
||||
|
||||
:param LabelTensor x: The labelized input of the forward pass of an
|
||||
instance of :class:`torch.nn.Module`.
|
||||
:param Iterable args: Additional positional arguments passed to
|
||||
``forward`` method.
|
||||
:param dict kwargs: Additional keyword arguments passed to
|
||||
``forward`` method.
|
||||
:return: The labelized output of the forward pass of an instance of
|
||||
:class:`torch.nn.Module`.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
x = x.extract(input_variables)
|
||||
output = forward(x)
|
||||
output = forward(x, *args, **kwargs)
|
||||
# keep it like this, directly using LabelTensor(...) raises errors
|
||||
# when compiling the code
|
||||
output = output.as_subclass(LabelTensor)
|
||||
|
||||
Reference in New Issue
Block a user