add callable function as reduction
This commit is contained in:
committed by
Nicola Demo
parent
c4df59ebcd
commit
44cf800491
@@ -17,12 +17,11 @@ class SystemEquation(Equation):
|
|||||||
:param callable equation: A ``torch`` callable equation to
|
:param callable equation: A ``torch`` callable equation to
|
||||||
evaluate the residual
|
evaluate the residual
|
||||||
:param str reduction: Specifies the reduction to apply to the output:
|
:param str reduction: Specifies the reduction to apply to the output:
|
||||||
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
|
``none`` | ``mean`` | ``sum`` | ``callable``. ``none``: no reduction
|
||||||
will be applied, ``'mean'``: the sum of the output will be divided
|
will be applied, ``mean``: the sum of the output will be divided
|
||||||
by the number of elements in the output, ``'sum'``: the output will
|
by the number of elements in the output, ``sum``: the output will
|
||||||
be summed. Note: :attr:`size_average` and :attr:`reduce` are in the
|
be summed. ``callable`` a callable function to perform reduction,
|
||||||
process of being deprecated, and in the meantime, specifying either of
|
no checks guaranteed. Default: ``mean``.
|
||||||
those two args will override :attr:`reduction`. Default: ``'mean'``.
|
|
||||||
"""
|
"""
|
||||||
check_consistency([list_equation], list)
|
check_consistency([list_equation], list)
|
||||||
check_consistency(reduction, str)
|
check_consistency(reduction, str)
|
||||||
@@ -37,7 +36,7 @@ class SystemEquation(Equation):
|
|||||||
self.reduction = torch.mean
|
self.reduction = torch.mean
|
||||||
elif reduction == 'sum':
|
elif reduction == 'sum':
|
||||||
self.reduction = torch.sum
|
self.reduction = torch.sum
|
||||||
elif reduction == 'none':
|
elif (reduction == 'none') or callable(reduction):
|
||||||
self.reduction = reduction
|
self.reduction = reduction
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Only mean and sum reductions implemented.')
|
raise NotImplementedError('Only mean and sum reductions implemented.')
|
||||||
|
|||||||
Reference in New Issue
Block a user