Fixing Laplacian operator for vector fields (#380)

* fix laplacian and tests
This commit is contained in:
Giovanni Canali
2024-11-18 17:25:34 +01:00
committed by GitHub
parent db521ef468
commit a78f44ecef
2 changed files with 50 additions and 26 deletions

View File

@@ -170,53 +170,66 @@ def laplacian(output_, input_, components=None, d=None, method="std"):
is calculated. d should be a subset of the input labels. If None, all
the input variables are considered. Default is None.
:param str method: used method to calculate Laplacian, defaults to 'std'.
:raises ValueError: for vectorial field derivative with respect to
all coordinates must be performed.
:raises NotImplementedError: 'divgrad' not implemented as method.
:return: The tensor containing the result of the Laplacian operator.
:rtype: LabelTensor
"""
def scalar_laplace(output_, input_, components, d):
"""
Compute Laplace operator for a scalar output.
:param LabelTensor output_: the output tensor onto which computing the
Laplacian. It has to be a column tensor.
:param LabelTensor input_: the input tensor with respect to which
computing the Laplacian.
:param list(str) components: the name of the output variables to
calculate the Laplacian for. It should be a subset of the output
labels. If None, all the output variables are considered.
:param list(str) d: the name of the input variables on which the
Laplacian is computed. d should be a subset of the input labels.
If None, all the input variables are considered. Default is None.
:return: The tensor containing the result of the Laplacian operator.
:rtype: LabelTensor
"""
grad_output = grad(output_, input_, components=components, d=d)
result = torch.zeros(output_.shape[0], 1, device=output_.device)
for i, label in enumerate(grad_output.labels):
gg = grad(grad_output, input_, d=d, components=[label])
result[:, 0] += super(torch.Tensor, gg.T).__getitem__(i)
return result
if d is None:
d = input_.labels
if components is None:
components = output_.labels
if len(components) != len(d) and len(components) != 1:
raise ValueError
if method == "divgrad":
raise NotImplementedError("divgrad not implemented as method")
# TODO fix
# grad_output = grad(output_, input_, components, d)
# result = div(grad_output, input_, d=d)
elif method == "std":
elif method == "std":
if len(components) == 1:
grad_output = grad(output_, input_, components=components, d=d)
result = torch.zeros(output_.shape[0], 1, device=output_.device)
for i, label in enumerate(grad_output.labels):
gg = grad(grad_output, input_, d=d, components=[label])
result[:, 0] += super(torch.Tensor, gg.T).__getitem__(
i
) # TODO improve
result = scalar_laplace(output_, input_, components, d)
labels = [f"dd{components[0]}"]
else:
result = torch.empty(
input_.shape[0], len(components), device=output_.device
size=(input_.shape[0], len(components)),
dtype=output_.dtype, device=output_.device
)
labels = [None] * len(components)
for idx, (ci, di) in enumerate(zip(components, d)):
if not isinstance(ci, list):
ci = [ci]
if not isinstance(di, list):
di = [di]
grad_output = grad(output_, input_, components=ci, d=di)
result[:, idx] = grad(grad_output, input_, d=di).flatten()
labels[idx] = f"dd{ci}dd{di}"
for idx, c in enumerate(components):
result[:, idx] = scalar_laplace(output_, input_, c, d).flatten()
labels[idx] = f"dd{c}"
result = result.as_subclass(LabelTensor)
result.labels = labels

View File

@@ -52,15 +52,26 @@ def test_div_vector_output():
def test_laplacian_scalar_output():
laplace_tensor_v = laplacian(tensor_s, inp, components=['a'], d=['x', 'y'])
assert laplace_tensor_v.shape == tensor_s.shape
laplace_tensor_s = laplacian(tensor_s, inp, components=['a'], d=['x', 'y'])
assert laplace_tensor_s.shape == tensor_s.shape
assert laplace_tensor_s.labels == [f"dd{tensor_s.labels[0]}"]
true_val = 4*torch.ones_like(laplace_tensor_s)
assert all((laplace_tensor_s - true_val == 0).flatten())
def test_laplacian_vector_output():
laplace_tensor_v = laplacian(tensor_v, inp)
assert laplace_tensor_v.shape == tensor_v.shape
assert laplace_tensor_v.labels == [
f'dd{i}' for i in tensor_v.labels
]
laplace_tensor_v = laplacian(tensor_v,
inp,
components=['a', 'b'],
d=['x', 'y'])
assert laplace_tensor_v.shape == tensor_v.extract(['a', 'b']).shape
assert laplace_tensor_v.labels == [
f'dd{i}' for i in ['a', 'b']
]
true_val = 2*torch.ones_like(tensor_v.extract(['a', 'b']))
assert all((laplace_tensor_v - true_val == 0).flatten())