support built-in equations in system
This commit is contained in:
@@ -8,18 +8,51 @@ from ..utils import check_consistency
|
|||||||
|
|
||||||
class SystemEquation(EquationInterface):
|
class SystemEquation(EquationInterface):
|
||||||
"""
|
"""
|
||||||
Implementation of the System of Equations. Every ``equation`` passed to a
|
Implementation of the System of Equations, to be passed to a
|
||||||
:class:`~pina.condition.condition.Condition` object must be either a
|
:class:`~pina.condition.condition.Condition` object.
|
||||||
:class:`~pina.equation.equation.Equation` or a
|
|
||||||
:class:`~pina.equation.system_equation.SystemEquation` instance.
|
Unlike the :class:`~pina.equation.equation.Equation` class, which represents
|
||||||
|
a single equation, the :class:`SystemEquation` class allows multiple
|
||||||
|
equations to be grouped together into a system. This is particularly useful
|
||||||
|
when dealing with multi-component outputs or coupled physical models, where
|
||||||
|
the residual must be computed collectively across several constraints.
|
||||||
|
|
||||||
|
Each equation in the system must be either:
|
||||||
|
- An instance of :class:`~pina.equation.equation.Equation`;
|
||||||
|
- A callable function.
|
||||||
|
|
||||||
|
The residuals from each equation are computed independently and then
|
||||||
|
aggregated using an optional reduction strategy (e.g., ``mean``, ``sum``).
|
||||||
|
The resulting residual is returned as a single :class:`~pina.LabelTensor`.
|
||||||
|
|
||||||
|
:Example:
|
||||||
|
|
||||||
|
>>> from pina.equation import SystemEquation, FixedValue, FixedGradient
|
||||||
|
>>> from pina import LabelTensor
|
||||||
|
>>> import torch
|
||||||
|
>>> pts = LabelTensor(torch.rand(10, 2), labels=["x", "y"])
|
||||||
|
>>> pts.requires_grad = True
|
||||||
|
>>> output_ = torch.pow(pts, 2)
|
||||||
|
>>> output_.labels = ["u", "v"]
|
||||||
|
>>> system_equation = SystemEquation(
|
||||||
|
... [
|
||||||
|
... FixedValue(value=1.0, components=["u"]),
|
||||||
|
... FixedGradient(value=0.0, components=["v"],d=["y"]),
|
||||||
|
... ],
|
||||||
|
... reduction="mean",
|
||||||
|
... )
|
||||||
|
>>> residual = system_equation.residual(pts, output_)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, list_equation, reduction=None):
|
def __init__(self, list_equation, reduction=None):
|
||||||
"""
|
"""
|
||||||
Initialization of the :class:`SystemEquation` class.
|
Initialization of the :class:`SystemEquation` class.
|
||||||
|
|
||||||
:param Callable equation: A ``torch`` callable function used to compute
|
:param list_equation: A list containing either callable functions or
|
||||||
the residual of a mathematical equation.
|
instances of :class:`~pina.equation.equation.Equation`, used to
|
||||||
|
compute the residuals of mathematical equations.
|
||||||
|
:type list_equation: list[Callable] | list[Equation]
|
||||||
:param str reduction: The reduction method to aggregate the residuals of
|
:param str reduction: The reduction method to aggregate the residuals of
|
||||||
each equation. Available options are: ``None``, ``mean``, ``sum``,
|
each equation. Available options are: ``None``, ``mean``, ``sum``,
|
||||||
``callable``.
|
``callable``.
|
||||||
@@ -32,9 +65,10 @@ class SystemEquation(EquationInterface):
|
|||||||
check_consistency([list_equation], list)
|
check_consistency([list_equation], list)
|
||||||
|
|
||||||
# equations definition
|
# equations definition
|
||||||
self.equations = []
|
self.equations = [
|
||||||
for _, equation in enumerate(list_equation):
|
equation if isinstance(equation, Equation) else Equation(equation)
|
||||||
self.equations.append(Equation(equation))
|
for equation in list_equation
|
||||||
|
]
|
||||||
|
|
||||||
# possible reduction
|
# possible reduction
|
||||||
if reduction == "mean":
|
if reduction == "mean":
|
||||||
@@ -45,7 +79,7 @@ class SystemEquation(EquationInterface):
|
|||||||
self.reduction = reduction
|
self.reduction = reduction
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Only mean and sum reductions implemented."
|
"Only mean and sum reductions are currenly supported."
|
||||||
)
|
)
|
||||||
|
|
||||||
def residual(self, input_, output_, params_=None):
|
def residual(self, input_, output_, params_=None):
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from pina.equation import SystemEquation
|
from pina.equation import SystemEquation, FixedValue, FixedGradient
|
||||||
from pina.operator import grad, laplacian
|
from pina.operator import grad, laplacian
|
||||||
from pina import LabelTensor
|
from pina import LabelTensor
|
||||||
import torch
|
import torch
|
||||||
@@ -24,34 +24,78 @@ def foo():
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def test_constructor():
|
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
|
||||||
SystemEquation([eq1, eq2])
|
def test_constructor(reduction):
|
||||||
SystemEquation([eq1, eq2], reduction="sum")
|
|
||||||
|
# Constructor with callable functions
|
||||||
|
SystemEquation([eq1, eq2], reduction=reduction)
|
||||||
|
|
||||||
|
# Constructor with Equation instances
|
||||||
|
SystemEquation(
|
||||||
|
[
|
||||||
|
FixedValue(value=0.0, components=["u1"]),
|
||||||
|
FixedGradient(value=0.0, components=["u2"]),
|
||||||
|
],
|
||||||
|
reduction=reduction,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Constructor with mixed types
|
||||||
|
SystemEquation(
|
||||||
|
[
|
||||||
|
FixedValue(value=0.0, components=["u1"]),
|
||||||
|
eq1,
|
||||||
|
],
|
||||||
|
reduction=reduction,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Non-standard reduction not implemented
|
||||||
with pytest.raises(NotImplementedError):
|
with pytest.raises(NotImplementedError):
|
||||||
SystemEquation([eq1, eq2], reduction="foo")
|
SystemEquation([eq1, eq2], reduction="foo")
|
||||||
|
|
||||||
|
# Invalid input type
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
SystemEquation(foo)
|
SystemEquation(foo)
|
||||||
|
|
||||||
|
|
||||||
def test_residual():
|
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
|
||||||
|
def test_residual(reduction):
|
||||||
|
|
||||||
|
# Generate random points and output
|
||||||
pts = LabelTensor(torch.rand(10, 2), labels=["x", "y"])
|
pts = LabelTensor(torch.rand(10, 2), labels=["x", "y"])
|
||||||
pts.requires_grad = True
|
pts.requires_grad = True
|
||||||
u = torch.pow(pts, 2)
|
u = torch.pow(pts, 2)
|
||||||
u.labels = ["u1", "u2"]
|
u.labels = ["u1", "u2"]
|
||||||
|
|
||||||
eq_1 = SystemEquation([eq1, eq2], reduction="mean")
|
# System with callable functions
|
||||||
res = eq_1.residual(pts, u)
|
system_eq = SystemEquation([eq1, eq2], reduction=reduction)
|
||||||
assert res.shape == torch.Size([10])
|
res = system_eq.residual(pts, u)
|
||||||
|
|
||||||
eq_1 = SystemEquation([eq1, eq2], reduction="sum")
|
# Checks on the shape of the residual
|
||||||
res = eq_1.residual(pts, u)
|
shape = torch.Size([10, 3]) if reduction is None else torch.Size([10])
|
||||||
assert res.shape == torch.Size([10])
|
assert res.shape == shape
|
||||||
|
|
||||||
eq_1 = SystemEquation([eq1, eq2], reduction=None)
|
# System with Equation instances
|
||||||
res = eq_1.residual(pts, u)
|
system_eq = SystemEquation(
|
||||||
assert res.shape == torch.Size([10, 3])
|
[
|
||||||
|
FixedValue(value=0.0, components=["u1"]),
|
||||||
|
FixedGradient(value=0.0, components=["u2"]),
|
||||||
|
],
|
||||||
|
reduction=reduction,
|
||||||
|
)
|
||||||
|
|
||||||
eq_1 = SystemEquation([eq1, eq2])
|
# Checks on the shape of the residual
|
||||||
res = eq_1.residual(pts, u)
|
shape = torch.Size([10, 3]) if reduction is None else torch.Size([10])
|
||||||
assert res.shape == torch.Size([10, 3])
|
assert res.shape == shape
|
||||||
|
|
||||||
|
# System with mixed types
|
||||||
|
system_eq = SystemEquation(
|
||||||
|
[
|
||||||
|
FixedValue(value=0.0, components=["u1"]),
|
||||||
|
eq1,
|
||||||
|
],
|
||||||
|
reduction=reduction,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Checks on the shape of the residual
|
||||||
|
shape = torch.Size([10, 3]) if reduction is None else torch.Size([10])
|
||||||
|
assert res.shape == shape
|
||||||
|
|||||||
Reference in New Issue
Block a user