diff --git a/pina/equation/system_equation.py b/pina/equation/system_equation.py index 9f6d78f..df852dc 100644 --- a/pina/equation/system_equation.py +++ b/pina/equation/system_equation.py @@ -17,12 +17,11 @@ class SystemEquation(Equation): :param callable equation: A ``torch`` callable equation to evaluate the residual :param str reduction: Specifies the reduction to apply to the output: - ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction - will be applied, ``'mean'``: the sum of the output will be divided - by the number of elements in the output, ``'sum'``: the output will - be summed. Note: :attr:`size_average` and :attr:`reduce` are in the - process of being deprecated, and in the meantime, specifying either of - those two args will override :attr:`reduction`. Default: ``'mean'``. + ``none`` | ``mean`` | ``sum`` | ``callable``. ``none``: no reduction + will be applied, ``mean``: the sum of the output will be divided + by the number of elements in the output, ``sum``: the output will + be summed. ``callable`` a callable function to perform reduction, + no checks guaranteed. Default: ``mean``. """ check_consistency([list_equation], list) check_consistency(reduction, str) @@ -37,7 +36,7 @@ class SystemEquation(Equation): self.reduction = torch.mean elif reduction == 'sum': self.reduction = torch.sum - elif reduction == 'none': + elif (reduction == 'none') or callable(reduction): self.reduction = reduction else: raise NotImplementedError('Only mean and sum reductions implemented.')