Change default reduction in SystemEquation (#317)
* Update system_equation.py * Update test_systemequation.py
This commit is contained in:
@@ -7,7 +7,7 @@ from ..utils import check_consistency
|
|||||||
|
|
||||||
class SystemEquation(Equation):
|
class SystemEquation(Equation):
|
||||||
|
|
||||||
def __init__(self, list_equation, reduction="mean"):
|
def __init__(self, list_equation, reduction=None):
|
||||||
"""
|
"""
|
||||||
System of Equation class for specifing any system
|
System of Equation class for specifing any system
|
||||||
of equations in PINA.
|
of equations in PINA.
|
||||||
@@ -19,14 +19,13 @@ class SystemEquation(Equation):
|
|||||||
:param Callable equation: A ``torch`` callable equation to
|
:param Callable equation: A ``torch`` callable equation to
|
||||||
evaluate the residual
|
evaluate the residual
|
||||||
:param str reduction: Specifies the reduction to apply to the output:
|
:param str reduction: Specifies the reduction to apply to the output:
|
||||||
``none`` | ``mean`` | ``sum`` | ``callable``. ``none``: no reduction
|
None | ``mean`` | ``sum`` | callable. None: no reduction
|
||||||
will be applied, ``mean``: the sum of the output will be divided
|
will be applied, ``mean``: the output sum will be divided
|
||||||
by the number of elements in the output, ``sum``: the output will
|
by the number of elements in the output, ``sum``: the output will
|
||||||
be summed. ``callable`` a callable function to perform reduction,
|
be summed. *callable* is a callable function to perform reduction,
|
||||||
no checks guaranteed. Default: ``mean``.
|
no checks guaranteed. Default: None.
|
||||||
"""
|
"""
|
||||||
check_consistency([list_equation], list)
|
check_consistency([list_equation], list)
|
||||||
check_consistency(reduction, str)
|
|
||||||
|
|
||||||
# equations definition
|
# equations definition
|
||||||
self.equations = []
|
self.equations = []
|
||||||
@@ -38,7 +37,7 @@ class SystemEquation(Equation):
|
|||||||
self.reduction = torch.mean
|
self.reduction = torch.mean
|
||||||
elif reduction == "sum":
|
elif reduction == "sum":
|
||||||
self.reduction = torch.sum
|
self.reduction = torch.sum
|
||||||
elif (reduction == "none") or callable(reduction):
|
elif (reduction == None) or callable(reduction):
|
||||||
self.reduction = reduction
|
self.reduction = reduction
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
@@ -72,7 +71,7 @@ class SystemEquation(Equation):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.reduction == "none":
|
if self.reduction is None:
|
||||||
return residual
|
return residual
|
||||||
|
|
||||||
return self.reduction(residual, dim=-1)
|
return self.reduction(residual, dim=-1)
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ def test_residual():
|
|||||||
u = torch.pow(pts, 2)
|
u = torch.pow(pts, 2)
|
||||||
u.labels = ['u1', 'u2']
|
u.labels = ['u1', 'u2']
|
||||||
|
|
||||||
eq_1 = SystemEquation([eq1, eq2])
|
eq_1 = SystemEquation([eq1, eq2], reduction='mean')
|
||||||
res = eq_1.residual(pts, u)
|
res = eq_1.residual(pts, u)
|
||||||
assert res.shape == torch.Size([10])
|
assert res.shape == torch.Size([10])
|
||||||
|
|
||||||
@@ -47,6 +47,10 @@ def test_residual():
|
|||||||
res = eq_1.residual(pts, u)
|
res = eq_1.residual(pts, u)
|
||||||
assert res.shape == torch.Size([10])
|
assert res.shape == torch.Size([10])
|
||||||
|
|
||||||
eq_1 = SystemEquation([eq1, eq2], reduction='none')
|
eq_1 = SystemEquation([eq1, eq2], reduction=None)
|
||||||
|
res = eq_1.residual(pts, u)
|
||||||
|
assert res.shape == torch.Size([10, 3])
|
||||||
|
|
||||||
|
eq_1 = SystemEquation([eq1, eq2])
|
||||||
res = eq_1.residual(pts, u)
|
res = eq_1.residual(pts, u)
|
||||||
assert res.shape == torch.Size([10, 3])
|
assert res.shape == torch.Size([10, 3])
|
||||||
|
|||||||
Reference in New Issue
Block a user