Refactoring solvers (#541)

* Refactoring solvers

* Simplify logic compile
* Improve and update doc
* Create SupervisedSolverInterface
* Specialize SupervisedSolver and ReducedOrderModelSolver
* Create EnsembleSolverInterface + EnsembleSupervisedSolver
* Create tests ensemble solvers

* formatter

* codacy

* fix issues + speedup test
This commit is contained in:
Dario Coscia
2025-04-09 14:51:42 +02:00
parent 485c8dd789
commit 6dd7bd2825
37 changed files with 1514 additions and 510 deletions

View File

@@ -72,18 +72,22 @@ def labelize_forward(forward, input_variables, output_variables):
:rtype: Callable
"""
def wrapper(x):
def wrapper(x, *args, **kwargs):
"""
Decorated forward function.
:param LabelTensor x: The labelized input of the forward pass of an
instance of :class:`torch.nn.Module`.
:param Iterable args: Additional positional arguments passed to
``forward`` method.
:param dict kwargs: Additional keyword arguments passed to
``forward`` method.
:return: The labelized output of the forward pass of an instance of
:class:`torch.nn.Module`.
:rtype: LabelTensor
"""
x = x.extract(input_variables)
output = forward(x)
output = forward(x, *args, **kwargs)
# keep it like this, directly using LabelTensor(...) raises errors
# when compiling the code
output = output.as_subclass(LabelTensor)