add callable function as reduction

This commit is contained in:
Dario Coscia
2023-06-29 18:31:26 +02:00
committed by Nicola Demo
parent c4df59ebcd
commit 44cf800491

View File

@@ -17,12 +17,11 @@ class SystemEquation(Equation):
:param callable equation: A ``torch`` callable equation to :param callable equation: A ``torch`` callable equation to
evaluate the residual evaluate the residual
:param str reduction: Specifies the reduction to apply to the output: :param str reduction: Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction ``none`` | ``mean`` | ``sum`` | ``callable``. ``none``: no reduction
will be applied, ``'mean'``: the sum of the output will be divided will be applied, ``mean``: the sum of the output will be divided
by the number of elements in the output, ``'sum'``: the output will by the number of elements in the output, ``sum``: the output will
be summed. Note: :attr:`size_average` and :attr:`reduce` are in the be summed. ``callable`` a callable function to perform reduction,
process of being deprecated, and in the meantime, specifying either of no checks guaranteed. Default: ``mean``.
those two args will override :attr:`reduction`. Default: ``'mean'``.
""" """
check_consistency([list_equation], list) check_consistency([list_equation], list)
check_consistency(reduction, str) check_consistency(reduction, str)
@@ -37,7 +36,7 @@ class SystemEquation(Equation):
self.reduction = torch.mean self.reduction = torch.mean
elif reduction == 'sum': elif reduction == 'sum':
self.reduction = torch.sum self.reduction = torch.sum
elif reduction == 'none': elif (reduction == 'none') or callable(reduction):
self.reduction = reduction self.reduction = reduction
else: else:
raise NotImplementedError('Only mean and sum reductions implemented.') raise NotImplementedError('Only mean and sum reductions implemented.')