Fixing Laplacian operator for vector fields (#380)

* fix laplacian and tests
This commit is contained in:
Giovanni Canali
2024-11-18 17:25:34 +01:00
committed by GitHub
parent db521ef468
commit a78f44ecef
2 changed files with 50 additions and 26 deletions

View File

@@ -52,15 +52,26 @@ def test_div_vector_output():
def test_laplacian_scalar_output():
laplace_tensor_v = laplacian(tensor_s, inp, components=['a'], d=['x', 'y'])
assert laplace_tensor_v.shape == tensor_s.shape
laplace_tensor_s = laplacian(tensor_s, inp, components=['a'], d=['x', 'y'])
assert laplace_tensor_s.shape == tensor_s.shape
assert laplace_tensor_s.labels == [f"dd{tensor_s.labels[0]}"]
true_val = 4*torch.ones_like(laplace_tensor_s)
assert all((laplace_tensor_s - true_val == 0).flatten())
def test_laplacian_vector_output():
laplace_tensor_v = laplacian(tensor_v, inp)
assert laplace_tensor_v.shape == tensor_v.shape
assert laplace_tensor_v.labels == [
f'dd{i}' for i in tensor_v.labels
]
laplace_tensor_v = laplacian(tensor_v,
inp,
components=['a', 'b'],
d=['x', 'y'])
assert laplace_tensor_v.shape == tensor_v.extract(['a', 'b']).shape
assert laplace_tensor_v.labels == [
f'dd{i}' for i in ['a', 'b']
]
true_val = 2*torch.ones_like(tensor_v.extract(['a', 'b']))
assert all((laplace_tensor_v - true_val == 0).flatten())