From 2cb0eadac16313b7a02520863d36a3ab94bb5902 Mon Sep 17 00:00:00 2001 From: giovanni Date: Tue, 1 Jul 2025 12:37:33 +0200 Subject: [PATCH] support built-in equations in system --- pina/equation/system_equation.py | 54 +++++++++++--- tests/test_equations/test_system_equation.py | 78 +++++++++++++++----- 2 files changed, 105 insertions(+), 27 deletions(-) diff --git a/pina/equation/system_equation.py b/pina/equation/system_equation.py index d51ba94..21cb271 100644 --- a/pina/equation/system_equation.py +++ b/pina/equation/system_equation.py @@ -8,18 +8,51 @@ from ..utils import check_consistency class SystemEquation(EquationInterface): """ - Implementation of the System of Equations. Every ``equation`` passed to a - :class:`~pina.condition.condition.Condition` object must be either a - :class:`~pina.equation.equation.Equation` or a - :class:`~pina.equation.system_equation.SystemEquation` instance. + Implementation of the System of Equations, to be passed to a + :class:`~pina.condition.condition.Condition` object. + + Unlike the :class:`~pina.equation.equation.Equation` class, which represents + a single equation, the :class:`SystemEquation` class allows multiple + equations to be grouped together into a system. This is particularly useful + when dealing with multi-component outputs or coupled physical models, where + the residual must be computed collectively across several constraints. + + Each equation in the system must be either: + - An instance of :class:`~pina.equation.equation.Equation`; + - A callable function. + + The residuals from each equation are computed independently and then + aggregated using an optional reduction strategy (e.g., ``mean``, ``sum``). + The resulting residual is returned as a single :class:`~pina.LabelTensor`. + + :Example: + + >>> from pina.equation import SystemEquation, FixedValue, FixedGradient + >>> from pina import LabelTensor + >>> import torch + >>> pts = LabelTensor(torch.rand(10, 2), labels=["x", "y"]) + >>> pts.requires_grad = True + >>> output_ = torch.pow(pts, 2) + >>> output_.labels = ["u", "v"] + >>> system_equation = SystemEquation( + ... [ + ... FixedValue(value=1.0, components=["u"]), + ... FixedGradient(value=0.0, components=["v"],d=["y"]), + ... ], + ... reduction="mean", + ... ) + >>> residual = system_equation.residual(pts, output_) + """ def __init__(self, list_equation, reduction=None): """ Initialization of the :class:`SystemEquation` class. - :param Callable equation: A ``torch`` callable function used to compute - the residual of a mathematical equation. + :param list_equation: A list containing either callable functions or + instances of :class:`~pina.equation.equation.Equation`, used to + compute the residuals of mathematical equations. + :type list_equation: list[Callable] | list[Equation] :param str reduction: The reduction method to aggregate the residuals of each equation. Available options are: ``None``, ``mean``, ``sum``, ``callable``. @@ -32,9 +65,10 @@ class SystemEquation(EquationInterface): check_consistency([list_equation], list) # equations definition - self.equations = [] - for _, equation in enumerate(list_equation): - self.equations.append(Equation(equation)) + self.equations = [ + equation if isinstance(equation, Equation) else Equation(equation) + for equation in list_equation + ] # possible reduction if reduction == "mean": @@ -45,7 +79,7 @@ class SystemEquation(EquationInterface): self.reduction = reduction else: raise NotImplementedError( - "Only mean and sum reductions implemented." + "Only mean and sum reductions are currenly supported." ) def residual(self, input_, output_, params_=None): diff --git a/tests/test_equations/test_system_equation.py b/tests/test_equations/test_system_equation.py index 4a0a116..bf62681 100644 --- a/tests/test_equations/test_system_equation.py +++ b/tests/test_equations/test_system_equation.py @@ -1,4 +1,4 @@ -from pina.equation import SystemEquation +from pina.equation import SystemEquation, FixedValue, FixedGradient from pina.operator import grad, laplacian from pina import LabelTensor import torch @@ -24,34 +24,78 @@ def foo(): pass -def test_constructor(): - SystemEquation([eq1, eq2]) - SystemEquation([eq1, eq2], reduction="sum") +@pytest.mark.parametrize("reduction", [None, "mean", "sum"]) +def test_constructor(reduction): + + # Constructor with callable functions + SystemEquation([eq1, eq2], reduction=reduction) + + # Constructor with Equation instances + SystemEquation( + [ + FixedValue(value=0.0, components=["u1"]), + FixedGradient(value=0.0, components=["u2"]), + ], + reduction=reduction, + ) + + # Constructor with mixed types + SystemEquation( + [ + FixedValue(value=0.0, components=["u1"]), + eq1, + ], + reduction=reduction, + ) + + # Non-standard reduction not implemented with pytest.raises(NotImplementedError): SystemEquation([eq1, eq2], reduction="foo") + + # Invalid input type with pytest.raises(ValueError): SystemEquation(foo) -def test_residual(): +@pytest.mark.parametrize("reduction", [None, "mean", "sum"]) +def test_residual(reduction): + # Generate random points and output pts = LabelTensor(torch.rand(10, 2), labels=["x", "y"]) pts.requires_grad = True u = torch.pow(pts, 2) u.labels = ["u1", "u2"] - eq_1 = SystemEquation([eq1, eq2], reduction="mean") - res = eq_1.residual(pts, u) - assert res.shape == torch.Size([10]) + # System with callable functions + system_eq = SystemEquation([eq1, eq2], reduction=reduction) + res = system_eq.residual(pts, u) - eq_1 = SystemEquation([eq1, eq2], reduction="sum") - res = eq_1.residual(pts, u) - assert res.shape == torch.Size([10]) + # Checks on the shape of the residual + shape = torch.Size([10, 3]) if reduction is None else torch.Size([10]) + assert res.shape == shape - eq_1 = SystemEquation([eq1, eq2], reduction=None) - res = eq_1.residual(pts, u) - assert res.shape == torch.Size([10, 3]) + # System with Equation instances + system_eq = SystemEquation( + [ + FixedValue(value=0.0, components=["u1"]), + FixedGradient(value=0.0, components=["u2"]), + ], + reduction=reduction, + ) - eq_1 = SystemEquation([eq1, eq2]) - res = eq_1.residual(pts, u) - assert res.shape == torch.Size([10, 3]) + # Checks on the shape of the residual + shape = torch.Size([10, 3]) if reduction is None else torch.Size([10]) + assert res.shape == shape + + # System with mixed types + system_eq = SystemEquation( + [ + FixedValue(value=0.0, components=["u1"]), + eq1, + ], + reduction=reduction, + ) + + # Checks on the shape of the residual + shape = torch.Size([10, 3]) if reduction is None else torch.Size([10]) + assert res.shape == shape