Change default reduction in SystemEquation (#317)
* Update system_equation.py * Update test_systemequation.py
This commit is contained in:
@@ -7,7 +7,7 @@ from ..utils import check_consistency
|
||||
|
||||
class SystemEquation(Equation):
|
||||
|
||||
def __init__(self, list_equation, reduction="mean"):
|
||||
def __init__(self, list_equation, reduction=None):
|
||||
"""
|
||||
System of Equation class for specifing any system
|
||||
of equations in PINA.
|
||||
@@ -19,14 +19,13 @@ class SystemEquation(Equation):
|
||||
:param Callable equation: A ``torch`` callable equation to
|
||||
evaluate the residual
|
||||
:param str reduction: Specifies the reduction to apply to the output:
|
||||
``none`` | ``mean`` | ``sum`` | ``callable``. ``none``: no reduction
|
||||
will be applied, ``mean``: the sum of the output will be divided
|
||||
None | ``mean`` | ``sum`` | callable. None: no reduction
|
||||
will be applied, ``mean``: the output sum will be divided
|
||||
by the number of elements in the output, ``sum``: the output will
|
||||
be summed. ``callable`` a callable function to perform reduction,
|
||||
no checks guaranteed. Default: ``mean``.
|
||||
be summed. *callable* is a callable function to perform reduction,
|
||||
no checks guaranteed. Default: None.
|
||||
"""
|
||||
check_consistency([list_equation], list)
|
||||
check_consistency(reduction, str)
|
||||
|
||||
# equations definition
|
||||
self.equations = []
|
||||
@@ -38,7 +37,7 @@ class SystemEquation(Equation):
|
||||
self.reduction = torch.mean
|
||||
elif reduction == "sum":
|
||||
self.reduction = torch.sum
|
||||
elif (reduction == "none") or callable(reduction):
|
||||
elif (reduction == None) or callable(reduction):
|
||||
self.reduction = reduction
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
@@ -72,7 +71,7 @@ class SystemEquation(Equation):
|
||||
]
|
||||
)
|
||||
|
||||
if self.reduction == "none":
|
||||
if self.reduction is None:
|
||||
return residual
|
||||
|
||||
return self.reduction(residual, dim=-1)
|
||||
|
||||
@@ -39,7 +39,7 @@ def test_residual():
|
||||
u = torch.pow(pts, 2)
|
||||
u.labels = ['u1', 'u2']
|
||||
|
||||
eq_1 = SystemEquation([eq1, eq2])
|
||||
eq_1 = SystemEquation([eq1, eq2], reduction='mean')
|
||||
res = eq_1.residual(pts, u)
|
||||
assert res.shape == torch.Size([10])
|
||||
|
||||
@@ -47,6 +47,10 @@ def test_residual():
|
||||
res = eq_1.residual(pts, u)
|
||||
assert res.shape == torch.Size([10])
|
||||
|
||||
eq_1 = SystemEquation([eq1, eq2], reduction='none')
|
||||
eq_1 = SystemEquation([eq1, eq2], reduction=None)
|
||||
res = eq_1.residual(pts, u)
|
||||
assert res.shape == torch.Size([10, 3])
|
||||
|
||||
eq_1 = SystemEquation([eq1, eq2])
|
||||
res = eq_1.residual(pts, u)
|
||||
assert res.shape == torch.Size([10, 3])
|
||||
|
||||
Reference in New Issue
Block a user