fix bugs for helmholtz and advection (#686)

This commit is contained in:
Giovanni Canali
2025-10-30 10:15:38 +01:00
committed by GitHub
parent 64930c431f
commit fca3db7926
3 changed files with 9 additions and 10 deletions

View File

@@ -239,19 +239,19 @@ class Advection(Equation): # pylint: disable=R0903
) )
# Ensure consistency of c length # Ensure consistency of c length
if len(self.c) != (len(input_lbl) - 1) and len(self.c) > 1: if self.c.shape[-1] != len(input_lbl) - 1 and self.c.shape[-1] > 1:
raise ValueError( raise ValueError(
"If 'c' is passed as a list, its length must be equal to " "If 'c' is passed as a list, its length must be equal to "
"the number of spatial dimensions." "the number of spatial dimensions."
) )
# Repeat c to ensure consistent shape for advection # Repeat c to ensure consistent shape for advection
self.c = self.c.repeat(output_.shape[0], 1) c = self.c.repeat(output_.shape[0], 1)
if self.c.shape[1] != (len(input_lbl) - 1): if c.shape[1] != (len(input_lbl) - 1):
self.c = self.c.repeat(1, len(input_lbl) - 1) c = c.repeat(1, len(input_lbl) - 1)
# Add a dimension to c for the following operations # Add a dimension to c for the following operations
self.c = self.c.unsqueeze(-1) c = c.unsqueeze(-1)
# Compute the time derivative and the spatial gradient # Compute the time derivative and the spatial gradient
time_der = grad(output_, input_, components=None, d="t") time_der = grad(output_, input_, components=None, d="t")
@@ -262,7 +262,7 @@ class Advection(Equation): # pylint: disable=R0903
tmp = tmp.transpose(-1, -2) tmp = tmp.transpose(-1, -2)
# Compute advection term # Compute advection term
adv = (tmp * self.c).sum(dim=tmp.tensor.ndim - 2) adv = (tmp * c).sum(dim=tmp.tensor.ndim - 2)
return time_der + adv return time_der + adv

View File

@@ -48,11 +48,10 @@ class HelmholtzProblem(SpatialProblem):
:type alpha: float | int :type alpha: float | int
""" """
super().__init__() super().__init__()
self.alpha = alpha
check_consistency(alpha, (int, float)) check_consistency(alpha, (int, float))
self.alpha = alpha
def forcing_term(self, input_): def forcing_term(input_):
""" """
Implementation of the forcing term. Implementation of the forcing term.
""" """

View File

@@ -104,7 +104,7 @@ def test_advection_equation(c):
# Should fail if c is a list and its length != spatial dimension # Should fail if c is a list and its length != spatial dimension
with pytest.raises(ValueError): with pytest.raises(ValueError):
Advection([1, 2, 3]) equation = Advection([1, 2, 3])
residual = equation.residual(pts, u) residual = equation.residual(pts, u)