pass method argument to fast laplacian (#648)
This commit is contained in:
@@ -221,6 +221,7 @@ def fast_laplacian(output_, input_, components, d, method="std"):
|
|||||||
divergence of the gradient. Default is ``std``.
|
divergence of the gradient. Default is ``std``.
|
||||||
:return: The computed laplacian tensor.
|
:return: The computed laplacian tensor.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
|
:raises ValueError: If the passed method is neither ``std`` nor ``divgrad``.
|
||||||
"""
|
"""
|
||||||
# Scalar laplacian
|
# Scalar laplacian
|
||||||
if output_.shape[-1] == 1:
|
if output_.shape[-1] == 1:
|
||||||
@@ -415,8 +416,13 @@ def laplacian(output_, input_, components=None, d=None, method="std"):
|
|||||||
components, d = _check_values(
|
components, d = _check_values(
|
||||||
output_=output_, input_=input_, components=components, d=d
|
output_=output_, input_=input_, components=components, d=d
|
||||||
)
|
)
|
||||||
|
|
||||||
return fast_laplacian(
|
return fast_laplacian(
|
||||||
output_=output_, input_=input_, components=components, d=d
|
output_=output_,
|
||||||
|
input_=input_,
|
||||||
|
components=components,
|
||||||
|
d=d,
|
||||||
|
method=method,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -253,7 +253,8 @@ def test_divergence(f):
|
|||||||
Function(),
|
Function(),
|
||||||
ids=["scalar_scalar", "scalar_vector", "vector_scalar", "vector_vector"],
|
ids=["scalar_scalar", "scalar_vector", "vector_scalar", "vector_vector"],
|
||||||
)
|
)
|
||||||
def test_laplacian(f):
|
@pytest.mark.parametrize("method", ["std", "divgrad"])
|
||||||
|
def test_laplacian(f, method):
|
||||||
|
|
||||||
# Unpack the function
|
# Unpack the function
|
||||||
func_input, func, _, _, func_lap = f
|
func_input, func, _, _, func_lap = f
|
||||||
@@ -265,7 +266,7 @@ def test_laplacian(f):
|
|||||||
output_ = LabelTensor(output_, labels)
|
output_ = LabelTensor(output_, labels)
|
||||||
|
|
||||||
# Compute the true laplacian and the pina laplacian
|
# Compute the true laplacian and the pina laplacian
|
||||||
pina_lap = laplacian(output_=output_, input_=input_)
|
pina_lap = laplacian(output_=output_, input_=input_, method=method)
|
||||||
true_lap = func_lap(input_)
|
true_lap = func_lap(input_)
|
||||||
|
|
||||||
# Check the shape and labels of the laplacian
|
# Check the shape and labels of the laplacian
|
||||||
@@ -276,24 +277,34 @@ def test_laplacian(f):
|
|||||||
assert torch.allclose(pina_lap, true_lap)
|
assert torch.allclose(pina_lap, true_lap)
|
||||||
|
|
||||||
# Test if labels are handled correctly
|
# Test if labels are handled correctly
|
||||||
laplacian(output_=output_, input_=input_, components=output_.labels[0])
|
laplacian(
|
||||||
laplacian(output_=output_, input_=input_, d=input_.labels[0])
|
output_=output_,
|
||||||
|
input_=input_,
|
||||||
|
components=output_.labels[0],
|
||||||
|
method=method,
|
||||||
|
)
|
||||||
|
laplacian(output_=output_, input_=input_, d=input_.labels[0], method=method)
|
||||||
|
|
||||||
# Should fail if input not a LabelTensor
|
# Should fail if input not a LabelTensor
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
laplacian(output_=output_, input_=input_.tensor)
|
laplacian(output_=output_, input_=input_.tensor, method=method)
|
||||||
|
|
||||||
# Should fail if output not a LabelTensor
|
# Should fail if output not a LabelTensor
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
laplacian(output_=output_.tensor, input_=input_)
|
laplacian(output_=output_.tensor, input_=input_, method=method)
|
||||||
|
|
||||||
# Should fail for non-existent input labels
|
# Should fail for non-existent input labels
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(RuntimeError):
|
||||||
laplacian(output_=output_, input_=input_, d=["x", "y"])
|
laplacian(output_=output_, input_=input_, d=["x", "y"], method=method)
|
||||||
|
|
||||||
# Should fail for non-existent output labels
|
# Should fail for non-existent output labels
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(RuntimeError):
|
||||||
laplacian(output_=output_, input_=input_, components=["a", "b", "c"])
|
laplacian(
|
||||||
|
output_=output_,
|
||||||
|
input_=input_,
|
||||||
|
components=["a", "b", "c"],
|
||||||
|
method=method,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_advection_scalar():
|
def test_advection_scalar():
|
||||||
|
|||||||
Reference in New Issue
Block a user