Update Laplace class and add unit tests (#645)

This commit is contained in:
Giovanni Canali
2025-09-22 15:05:28 +02:00
committed by GitHub
parent 4a6e73fa54
commit 4e37468460
15 changed files with 673 additions and 157 deletions

View File

@@ -2,46 +2,10 @@
import torch
from ... import Condition
from ...operator import laplacian
from ...equation import FixedValue, Helmholtz
from ...utils import check_consistency
from ...domain import CartesianDomain
from ...problem import SpatialProblem
from ...utils import check_consistency
from ...equation import Equation, FixedValue
class HelmholtzEquation(Equation):
"""
Implementation of the Helmholtz equation.
"""
def __init__(self, alpha):
"""
Initialization of the :class:`HelmholtzEquation` class.
:param alpha: Parameter of the forcing term.
:type alpha: float | int
"""
self.alpha = alpha
check_consistency(alpha, (int, float))
def equation(input_, output_):
"""
Implementation of the Helmholtz equation.
:param LabelTensor input_: Input data of the problem.
:param LabelTensor output_: Output data of the problem.
:return: The residual of the Helmholtz equation.
:rtype: LabelTensor
"""
lap = laplacian(output_, input_, components=["u"], d=["x", "y"])
q = (
(1 - 2 * (self.alpha * torch.pi) ** 2)
* torch.sin(self.alpha * torch.pi * input_.extract("x"))
* torch.sin(self.alpha * torch.pi * input_.extract("y"))
)
return lap + output_ - q
super().__init__(equation)
class HelmholtzProblem(SpatialProblem):
@@ -88,8 +52,19 @@ class HelmholtzProblem(SpatialProblem):
self.alpha = alpha
check_consistency(alpha, (int, float))
def forcing_term(self, input_):
"""
Implementation of the forcing term.
"""
return (
(1 - 2 * (self.alpha * torch.pi) ** 2)
* torch.sin(self.alpha * torch.pi * input_.extract("x"))
* torch.sin(self.alpha * torch.pi * input_.extract("y"))
)
self.conditions["D"] = Condition(
domain="D", equation=HelmholtzEquation(self.alpha)
domain="D",
equation=Helmholtz(self.alpha, forcing_term),
)
def solution(self, pts):