fix bugs for helmholtz and advection (#686)

This commit is contained in:
Giovanni Canali
2025-10-30 10:15:38 +01:00
committed by GitHub
parent 64930c431f
commit fca3db7926
3 changed files with 9 additions and 10 deletions

View File

@@ -239,19 +239,19 @@ class Advection(Equation): # pylint: disable=R0903
)
# Ensure consistency of c length
if len(self.c) != (len(input_lbl) - 1) and len(self.c) > 1:
if self.c.shape[-1] != len(input_lbl) - 1 and self.c.shape[-1] > 1:
raise ValueError(
"If 'c' is passed as a list, its length must be equal to "
"the number of spatial dimensions."
)
# Repeat c to ensure consistent shape for advection
self.c = self.c.repeat(output_.shape[0], 1)
if self.c.shape[1] != (len(input_lbl) - 1):
self.c = self.c.repeat(1, len(input_lbl) - 1)
c = self.c.repeat(output_.shape[0], 1)
if c.shape[1] != (len(input_lbl) - 1):
c = c.repeat(1, len(input_lbl) - 1)
# Add a dimension to c for the following operations
self.c = self.c.unsqueeze(-1)
c = c.unsqueeze(-1)
# Compute the time derivative and the spatial gradient
time_der = grad(output_, input_, components=None, d="t")
@@ -262,7 +262,7 @@ class Advection(Equation): # pylint: disable=R0903
tmp = tmp.transpose(-1, -2)
# Compute advection term
adv = (tmp * self.c).sum(dim=tmp.tensor.ndim - 2)
adv = (tmp * c).sum(dim=tmp.tensor.ndim - 2)
return time_der + adv

View File

@@ -48,11 +48,10 @@ class HelmholtzProblem(SpatialProblem):
:type alpha: float | int
"""
super().__init__()
self.alpha = alpha
check_consistency(alpha, (int, float))
self.alpha = alpha
def forcing_term(self, input_):
def forcing_term(input_):
"""
Implementation of the forcing term.
"""