Fixing Laplacian operator for vector fields (#380)

* fix laplacian and tests
This commit is contained in:
Giovanni Canali
2024-11-18 17:25:34 +01:00
committed by GitHub
parent db521ef468
commit a78f44ecef
2 changed files with 50 additions and 26 deletions

View File

@@ -170,53 +170,66 @@ def laplacian(output_, input_, components=None, d=None, method="std"):
is calculated. d should be a subset of the input labels. If None, all
the input variables are considered. Default is None.
:param str method: used method to calculate Laplacian, defaults to 'std'.
:raises ValueError: for vectorial field derivative with respect to
all coordinates must be performed.
:raises NotImplementedError: 'divgrad' not implemented as method.
:return: The tensor containing the result of the Laplacian operator.
:rtype: LabelTensor
"""
def scalar_laplace(output_, input_, components, d):
"""
Compute Laplace operator for a scalar output.
:param LabelTensor output_: the output tensor onto which computing the
Laplacian. It has to be a column tensor.
:param LabelTensor input_: the input tensor with respect to which
computing the Laplacian.
:param list(str) components: the name of the output variables to
calculate the Laplacian for. It should be a subset of the output
labels. If None, all the output variables are considered.
:param list(str) d: the name of the input variables on which the
Laplacian is computed. d should be a subset of the input labels.
If None, all the input variables are considered. Default is None.
:return: The tensor containing the result of the Laplacian operator.
:rtype: LabelTensor
"""
grad_output = grad(output_, input_, components=components, d=d)
result = torch.zeros(output_.shape[0], 1, device=output_.device)
for i, label in enumerate(grad_output.labels):
gg = grad(grad_output, input_, d=d, components=[label])
result[:, 0] += super(torch.Tensor, gg.T).__getitem__(i)
return result
if d is None:
d = input_.labels
if components is None:
components = output_.labels
if len(components) != len(d) and len(components) != 1:
raise ValueError
if method == "divgrad":
raise NotImplementedError("divgrad not implemented as method")
# TODO fix
# grad_output = grad(output_, input_, components, d)
# result = div(grad_output, input_, d=d)
elif method == "std":
elif method == "std":
if len(components) == 1:
grad_output = grad(output_, input_, components=components, d=d)
result = torch.zeros(output_.shape[0], 1, device=output_.device)
for i, label in enumerate(grad_output.labels):
gg = grad(grad_output, input_, d=d, components=[label])
result[:, 0] += super(torch.Tensor, gg.T).__getitem__(
i
) # TODO improve
result = scalar_laplace(output_, input_, components, d)
labels = [f"dd{components[0]}"]
else:
result = torch.empty(
input_.shape[0], len(components), device=output_.device
size=(input_.shape[0], len(components)),
dtype=output_.dtype, device=output_.device
)
labels = [None] * len(components)
for idx, (ci, di) in enumerate(zip(components, d)):
if not isinstance(ci, list):
ci = [ci]
if not isinstance(di, list):
di = [di]
grad_output = grad(output_, input_, components=ci, d=di)
result[:, idx] = grad(grad_output, input_, d=di).flatten()
labels[idx] = f"dd{ci}dd{di}"
for idx, c in enumerate(components):
result[:, idx] = scalar_laplace(output_, input_, c, d).flatten()
labels[idx] = f"dd{c}"
result = result.as_subclass(LabelTensor)
result.labels = labels