Change default reduction in SystemEquation (#317)

* Update system_equation.py
* Update test_systemequation.py
This commit is contained in:
luAndre00
2024-08-01 17:47:51 +02:00
committed by GitHub
parent f9316e359a
commit 16261c9baf
2 changed files with 13 additions and 10 deletions

View File

@@ -7,7 +7,7 @@ from ..utils import check_consistency
class SystemEquation(Equation): class SystemEquation(Equation):
def __init__(self, list_equation, reduction="mean"): def __init__(self, list_equation, reduction=None):
""" """
System of Equation class for specifing any system System of Equation class for specifing any system
of equations in PINA. of equations in PINA.
@@ -19,14 +19,13 @@ class SystemEquation(Equation):
:param Callable equation: A ``torch`` callable equation to :param Callable equation: A ``torch`` callable equation to
evaluate the residual evaluate the residual
:param str reduction: Specifies the reduction to apply to the output: :param str reduction: Specifies the reduction to apply to the output:
``none`` | ``mean`` | ``sum`` | ``callable``. ``none``: no reduction None | ``mean`` | ``sum`` | callable. None: no reduction
will be applied, ``mean``: the sum of the output will be divided will be applied, ``mean``: the output sum will be divided
by the number of elements in the output, ``sum``: the output will by the number of elements in the output, ``sum``: the output will
be summed. ``callable`` a callable function to perform reduction, be summed. *callable* is a callable function to perform reduction,
no checks guaranteed. Default: ``mean``. no checks guaranteed. Default: None.
""" """
check_consistency([list_equation], list) check_consistency([list_equation], list)
check_consistency(reduction, str)
# equations definition # equations definition
self.equations = [] self.equations = []
@@ -38,7 +37,7 @@ class SystemEquation(Equation):
self.reduction = torch.mean self.reduction = torch.mean
elif reduction == "sum": elif reduction == "sum":
self.reduction = torch.sum self.reduction = torch.sum
elif (reduction == "none") or callable(reduction): elif (reduction == None) or callable(reduction):
self.reduction = reduction self.reduction = reduction
else: else:
raise NotImplementedError( raise NotImplementedError(
@@ -72,7 +71,7 @@ class SystemEquation(Equation):
] ]
) )
if self.reduction == "none": if self.reduction is None:
return residual return residual
return self.reduction(residual, dim=-1) return self.reduction(residual, dim=-1)

View File

@@ -39,7 +39,7 @@ def test_residual():
u = torch.pow(pts, 2) u = torch.pow(pts, 2)
u.labels = ['u1', 'u2'] u.labels = ['u1', 'u2']
eq_1 = SystemEquation([eq1, eq2]) eq_1 = SystemEquation([eq1, eq2], reduction='mean')
res = eq_1.residual(pts, u) res = eq_1.residual(pts, u)
assert res.shape == torch.Size([10]) assert res.shape == torch.Size([10])
@@ -47,6 +47,10 @@ def test_residual():
res = eq_1.residual(pts, u) res = eq_1.residual(pts, u)
assert res.shape == torch.Size([10]) assert res.shape == torch.Size([10])
eq_1 = SystemEquation([eq1, eq2], reduction='none') eq_1 = SystemEquation([eq1, eq2], reduction=None)
res = eq_1.residual(pts, u)
assert res.shape == torch.Size([10, 3])
eq_1 = SystemEquation([eq1, eq2])
res = eq_1.residual(pts, u) res = eq_1.residual(pts, u)
assert res.shape == torch.Size([10, 3]) assert res.shape == torch.Size([10, 3])