Change default reduction in SystemEquation (#317)

* Update system_equation.py
* Update test_systemequation.py
This commit is contained in:
luAndre00
2024-08-01 17:47:51 +02:00
committed by GitHub
parent f9316e359a
commit 16261c9baf
2 changed files with 13 additions and 10 deletions

View File

@@ -7,7 +7,7 @@ from ..utils import check_consistency
class SystemEquation(Equation):
def __init__(self, list_equation, reduction="mean"):
def __init__(self, list_equation, reduction=None):
"""
System of Equation class for specifing any system
of equations in PINA.
@@ -19,14 +19,13 @@ class SystemEquation(Equation):
:param Callable equation: A ``torch`` callable equation to
evaluate the residual
:param str reduction: Specifies the reduction to apply to the output:
``none`` | ``mean`` | ``sum`` | ``callable``. ``none``: no reduction
will be applied, ``mean``: the sum of the output will be divided
None | ``mean`` | ``sum`` | callable. None: no reduction
will be applied, ``mean``: the output sum will be divided
by the number of elements in the output, ``sum``: the output will
be summed. ``callable`` a callable function to perform reduction,
no checks guaranteed. Default: ``mean``.
be summed. *callable* is a callable function to perform reduction,
no checks guaranteed. Default: None.
"""
check_consistency([list_equation], list)
check_consistency(reduction, str)
# equations definition
self.equations = []
@@ -38,7 +37,7 @@ class SystemEquation(Equation):
self.reduction = torch.mean
elif reduction == "sum":
self.reduction = torch.sum
elif (reduction == "none") or callable(reduction):
elif (reduction == None) or callable(reduction):
self.reduction = reduction
else:
raise NotImplementedError(
@@ -72,7 +71,7 @@ class SystemEquation(Equation):
]
)
if self.reduction == "none":
if self.reduction is None:
return residual
return self.reduction(residual, dim=-1)

View File

@@ -39,7 +39,7 @@ def test_residual():
u = torch.pow(pts, 2)
u.labels = ['u1', 'u2']
eq_1 = SystemEquation([eq1, eq2])
eq_1 = SystemEquation([eq1, eq2], reduction='mean')
res = eq_1.residual(pts, u)
assert res.shape == torch.Size([10])
@@ -47,6 +47,10 @@ def test_residual():
res = eq_1.residual(pts, u)
assert res.shape == torch.Size([10])
eq_1 = SystemEquation([eq1, eq2], reduction='none')
eq_1 = SystemEquation([eq1, eq2], reduction=None)
res = eq_1.residual(pts, u)
assert res.shape == torch.Size([10, 3])
eq_1 = SystemEquation([eq1, eq2])
res = eq_1.residual(pts, u)
assert res.shape == torch.Size([10, 3])