pass method argument to fast laplacian (#648)

This commit is contained in:
Giovanni Canali
2025-09-18 13:07:49 +02:00
committed by GitHub
parent dc808c1d77
commit 87c5c6a674
2 changed files with 26 additions and 9 deletions

View File

@@ -221,6 +221,7 @@ def fast_laplacian(output_, input_, components, d, method="std"):
         divergence of the gradient. Default is ``std``.
     :return: The computed laplacian tensor.
     :rtype: LabelTensor
+    :raises ValueError: If the passed method is neither ``std`` nor ``divgrad``.
     """
     # Scalar laplacian
     if output_.shape[-1] == 1:
@@ -415,8 +416,13 @@ def laplacian(output_, input_, components=None, d=None, method="std"):
     components, d = _check_values(
         output_=output_, input_=input_, components=components, d=d
     )
     return fast_laplacian(
-        output_=output_, input_=input_, components=components, d=d
+        output_=output_,
+        input_=input_,
+        components=components,
+        d=d,
+        method=method,
     )

View File

@@ -253,7 +253,8 @@ def test_divergence(f):
     Function(),
     ids=["scalar_scalar", "scalar_vector", "vector_scalar", "vector_vector"],
 )
-def test_laplacian(f):
+@pytest.mark.parametrize("method", ["std", "divgrad"])
+def test_laplacian(f, method):
     # Unpack the function
     func_input, func, _, _, func_lap = f
@@ -265,7 +266,7 @@ def test_laplacian(f):
     output_ = LabelTensor(output_, labels)
     # Compute the true laplacian and the pina laplacian
-    pina_lap = laplacian(output_=output_, input_=input_)
+    pina_lap = laplacian(output_=output_, input_=input_, method=method)
     true_lap = func_lap(input_)
     # Check the shape and labels of the laplacian
@@ -276,24 +277,34 @@ def test_laplacian(f):
     assert torch.allclose(pina_lap, true_lap)
     # Test if labels are handled correctly
-    laplacian(output_=output_, input_=input_, components=output_.labels[0])
-    laplacian(output_=output_, input_=input_, d=input_.labels[0])
+    laplacian(
+        output_=output_,
+        input_=input_,
+        components=output_.labels[0],
+        method=method,
+    )
+    laplacian(output_=output_, input_=input_, d=input_.labels[0], method=method)
     # Should fail if input not a LabelTensor
     with pytest.raises(TypeError):
-        laplacian(output_=output_, input_=input_.tensor)
+        laplacian(output_=output_, input_=input_.tensor, method=method)
     # Should fail if output not a LabelTensor
     with pytest.raises(TypeError):
-        laplacian(output_=output_.tensor, input_=input_)
+        laplacian(output_=output_.tensor, input_=input_, method=method)
     # Should fail for non-existent input labels
     with pytest.raises(RuntimeError):
-        laplacian(output_=output_, input_=input_, d=["x", "y"])
+        laplacian(output_=output_, input_=input_, d=["x", "y"], method=method)
     # Should fail for non-existent output labels
     with pytest.raises(RuntimeError):
-        laplacian(output_=output_, input_=input_, components=["a", "b", "c"])
+        laplacian(
+            output_=output_,
+            input_=input_,
+            components=["a", "b", "c"],
+            method=method,
+        )


 def test_advection_scalar():