support built-in equations in system

This commit is contained in:
giovanni
2025-07-01 12:37:33 +02:00
committed by Dario Coscia
parent de47d69fec
commit 2cb0eadac1
2 changed files with 105 additions and 27 deletions

View File

@@ -8,18 +8,51 @@ from ..utils import check_consistency
class SystemEquation(EquationInterface): class SystemEquation(EquationInterface):
""" """
Implementation of the System of Equations. Every ``equation`` passed to a Implementation of the System of Equations, to be passed to a
:class:`~pina.condition.condition.Condition` object must be either a :class:`~pina.condition.condition.Condition` object.
:class:`~pina.equation.equation.Equation` or a
:class:`~pina.equation.system_equation.SystemEquation` instance. Unlike the :class:`~pina.equation.equation.Equation` class, which represents
a single equation, the :class:`SystemEquation` class allows multiple
equations to be grouped together into a system. This is particularly useful
when dealing with multi-component outputs or coupled physical models, where
the residual must be computed collectively across several constraints.
Each equation in the system must be either:
- An instance of :class:`~pina.equation.equation.Equation`;
- A callable function.
The residuals from each equation are computed independently and then
aggregated using an optional reduction strategy (e.g., ``mean``, ``sum``).
The resulting residual is returned as a single :class:`~pina.LabelTensor`.
:Example:
>>> from pina.equation import SystemEquation, FixedValue, FixedGradient
>>> from pina import LabelTensor
>>> import torch
>>> pts = LabelTensor(torch.rand(10, 2), labels=["x", "y"])
>>> pts.requires_grad = True
>>> output_ = torch.pow(pts, 2)
>>> output_.labels = ["u", "v"]
>>> system_equation = SystemEquation(
... [
... FixedValue(value=1.0, components=["u"]),
... FixedGradient(value=0.0, components=["v"],d=["y"]),
... ],
... reduction="mean",
... )
>>> residual = system_equation.residual(pts, output_)
""" """
def __init__(self, list_equation, reduction=None): def __init__(self, list_equation, reduction=None):
""" """
Initialization of the :class:`SystemEquation` class. Initialization of the :class:`SystemEquation` class.
:param Callable equation: A ``torch`` callable function used to compute :param list_equation: A list containing either callable functions or
the residual of a mathematical equation. instances of :class:`~pina.equation.equation.Equation`, used to
compute the residuals of mathematical equations.
:type list_equation: list[Callable] | list[Equation]
:param str reduction: The reduction method to aggregate the residuals of :param str reduction: The reduction method to aggregate the residuals of
each equation. Available options are: ``None``, ``mean``, ``sum``, each equation. Available options are: ``None``, ``mean``, ``sum``,
``callable``. ``callable``.
@@ -32,9 +65,10 @@ class SystemEquation(EquationInterface):
check_consistency([list_equation], list) check_consistency([list_equation], list)
# equations definition # equations definition
self.equations = [] self.equations = [
for _, equation in enumerate(list_equation): equation if isinstance(equation, Equation) else Equation(equation)
self.equations.append(Equation(equation)) for equation in list_equation
]
# possible reduction # possible reduction
if reduction == "mean": if reduction == "mean":
@@ -45,7 +79,7 @@ class SystemEquation(EquationInterface):
self.reduction = reduction self.reduction = reduction
else: else:
raise NotImplementedError( raise NotImplementedError(
"Only mean and sum reductions implemented." "Only mean and sum reductions are currenly supported."
) )
def residual(self, input_, output_, params_=None): def residual(self, input_, output_, params_=None):

View File

@@ -1,4 +1,4 @@
from pina.equation import SystemEquation from pina.equation import SystemEquation, FixedValue, FixedGradient
from pina.operator import grad, laplacian from pina.operator import grad, laplacian
from pina import LabelTensor from pina import LabelTensor
import torch import torch
@@ -24,34 +24,78 @@ def foo():
pass pass
def test_constructor(): @pytest.mark.parametrize("reduction", [None, "mean", "sum"])
SystemEquation([eq1, eq2]) def test_constructor(reduction):
SystemEquation([eq1, eq2], reduction="sum")
# Constructor with callable functions
SystemEquation([eq1, eq2], reduction=reduction)
# Constructor with Equation instances
SystemEquation(
[
FixedValue(value=0.0, components=["u1"]),
FixedGradient(value=0.0, components=["u2"]),
],
reduction=reduction,
)
# Constructor with mixed types
SystemEquation(
[
FixedValue(value=0.0, components=["u1"]),
eq1,
],
reduction=reduction,
)
# Non-standard reduction not implemented
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
SystemEquation([eq1, eq2], reduction="foo") SystemEquation([eq1, eq2], reduction="foo")
# Invalid input type
with pytest.raises(ValueError): with pytest.raises(ValueError):
SystemEquation(foo) SystemEquation(foo)
def test_residual(): @pytest.mark.parametrize("reduction", [None, "mean", "sum"])
def test_residual(reduction):
# Generate random points and output
pts = LabelTensor(torch.rand(10, 2), labels=["x", "y"]) pts = LabelTensor(torch.rand(10, 2), labels=["x", "y"])
pts.requires_grad = True pts.requires_grad = True
u = torch.pow(pts, 2) u = torch.pow(pts, 2)
u.labels = ["u1", "u2"] u.labels = ["u1", "u2"]
eq_1 = SystemEquation([eq1, eq2], reduction="mean") # System with callable functions
res = eq_1.residual(pts, u) system_eq = SystemEquation([eq1, eq2], reduction=reduction)
assert res.shape == torch.Size([10]) res = system_eq.residual(pts, u)
eq_1 = SystemEquation([eq1, eq2], reduction="sum") # Checks on the shape of the residual
res = eq_1.residual(pts, u) shape = torch.Size([10, 3]) if reduction is None else torch.Size([10])
assert res.shape == torch.Size([10]) assert res.shape == shape
eq_1 = SystemEquation([eq1, eq2], reduction=None) # System with Equation instances
res = eq_1.residual(pts, u) system_eq = SystemEquation(
assert res.shape == torch.Size([10, 3]) [
FixedValue(value=0.0, components=["u1"]),
FixedGradient(value=0.0, components=["u2"]),
],
reduction=reduction,
)
eq_1 = SystemEquation([eq1, eq2]) # Checks on the shape of the residual
res = eq_1.residual(pts, u) shape = torch.Size([10, 3]) if reduction is None else torch.Size([10])
assert res.shape == torch.Size([10, 3]) assert res.shape == shape
# System with mixed types
system_eq = SystemEquation(
[
FixedValue(value=0.0, components=["u1"]),
eq1,
],
reduction=reduction,
)
# Checks on the shape of the residual
shape = torch.Size([10, 3]) if reduction is None else torch.Size([10])
assert res.shape == shape