fix bugs for helmholtz and advection (#686)
This commit is contained in:
@@ -239,19 +239,19 @@ class Advection(Equation): # pylint: disable=R0903
|
||||
)
|
||||
|
||||
# Ensure consistency of c length
|
||||
if len(self.c) != (len(input_lbl) - 1) and len(self.c) > 1:
|
||||
if self.c.shape[-1] != len(input_lbl) - 1 and self.c.shape[-1] > 1:
|
||||
raise ValueError(
|
||||
"If 'c' is passed as a list, its length must be equal to "
|
||||
"the number of spatial dimensions."
|
||||
)
|
||||
|
||||
# Repeat c to ensure consistent shape for advection
|
||||
self.c = self.c.repeat(output_.shape[0], 1)
|
||||
if self.c.shape[1] != (len(input_lbl) - 1):
|
||||
self.c = self.c.repeat(1, len(input_lbl) - 1)
|
||||
c = self.c.repeat(output_.shape[0], 1)
|
||||
if c.shape[1] != (len(input_lbl) - 1):
|
||||
c = c.repeat(1, len(input_lbl) - 1)
|
||||
|
||||
# Add a dimension to c for the following operations
|
||||
self.c = self.c.unsqueeze(-1)
|
||||
c = c.unsqueeze(-1)
|
||||
|
||||
# Compute the time derivative and the spatial gradient
|
||||
time_der = grad(output_, input_, components=None, d="t")
|
||||
@@ -262,7 +262,7 @@ class Advection(Equation): # pylint: disable=R0903
|
||||
tmp = tmp.transpose(-1, -2)
|
||||
|
||||
# Compute advection term
|
||||
adv = (tmp * self.c).sum(dim=tmp.tensor.ndim - 2)
|
||||
adv = (tmp * c).sum(dim=tmp.tensor.ndim - 2)
|
||||
|
||||
return time_der + adv
|
||||
|
||||
|
||||
@@ -48,11 +48,10 @@ class HelmholtzProblem(SpatialProblem):
|
||||
:type alpha: float | int
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.alpha = alpha
|
||||
check_consistency(alpha, (int, float))
|
||||
self.alpha = alpha
|
||||
|
||||
def forcing_term(self, input_):
|
||||
def forcing_term(input_):
|
||||
"""
|
||||
Implementation of the forcing term.
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user