From 16261c9baff1568a3261cec5745c08b25ffdb859 Mon Sep 17 00:00:00 2001 From: luAndre00 <130570716+luAndre00@users.noreply.github.com> Date: Thu, 1 Aug 2024 17:47:51 +0200 Subject: [PATCH] Change default reduction in SystemEquation (#317) * Update system_equation.py * Update test_systemequation.py --- pina/equation/system_equation.py | 15 +++++++-------- tests/test_equations/test_systemequation.py | 8 ++++++-- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/pina/equation/system_equation.py b/pina/equation/system_equation.py index 16d8c46..bf54abd 100644 --- a/pina/equation/system_equation.py +++ b/pina/equation/system_equation.py @@ -7,7 +7,7 @@ from ..utils import check_consistency class SystemEquation(Equation): - def __init__(self, list_equation, reduction="mean"): + def __init__(self, list_equation, reduction=None): """ System of Equation class for specifing any system of equations in PINA. @@ -19,14 +19,13 @@ class SystemEquation(Equation): :param Callable equation: A ``torch`` callable equation to evaluate the residual :param str reduction: Specifies the reduction to apply to the output: - ``none`` | ``mean`` | ``sum`` | ``callable``. ``none``: no reduction - will be applied, ``mean``: the sum of the output will be divided + None | ``mean`` | ``sum`` | callable. None: no reduction + will be applied, ``mean``: the output sum will be divided by the number of elements in the output, ``sum``: the output will - be summed. ``callable`` a callable function to perform reduction, - no checks guaranteed. Default: ``mean``. + be summed. *callable* is a callable function to perform reduction, + no checks guaranteed. Default: None. """ check_consistency([list_equation], list) - check_consistency(reduction, str) # equations definition self.equations = [] @@ -38,7 +37,7 @@ class SystemEquation(Equation): self.reduction = torch.mean elif reduction == "sum": self.reduction = torch.sum - elif (reduction == "none") or callable(reduction): + elif (reduction == None) or callable(reduction): self.reduction = reduction else: raise NotImplementedError( @@ -72,7 +71,7 @@ class SystemEquation(Equation): ] ) - if self.reduction == "none": + if self.reduction is None: return residual return self.reduction(residual, dim=-1) diff --git a/tests/test_equations/test_systemequation.py b/tests/test_equations/test_systemequation.py index ae6825b..7af90a7 100644 --- a/tests/test_equations/test_systemequation.py +++ b/tests/test_equations/test_systemequation.py @@ -39,7 +39,7 @@ def test_residual(): u = torch.pow(pts, 2) u.labels = ['u1', 'u2'] - eq_1 = SystemEquation([eq1, eq2]) + eq_1 = SystemEquation([eq1, eq2], reduction='mean') res = eq_1.residual(pts, u) assert res.shape == torch.Size([10]) @@ -47,6 +47,10 @@ def test_residual(): res = eq_1.residual(pts, u) assert res.shape == torch.Size([10]) - eq_1 = SystemEquation([eq1, eq2], reduction='none') + eq_1 = SystemEquation([eq1, eq2], reduction=None) + res = eq_1.residual(pts, u) + assert res.shape == torch.Size([10, 3]) + + eq_1 = SystemEquation([eq1, eq2]) res = eq_1.residual(pts, u) assert res.shape == torch.Size([10, 3])