Fixing Laplacian operator for vector fields (#380)
* fix laplacian and tests
@@ -170,53 +170,66 @@ def laplacian(output_, input_, components=None, d=None, method="std"):
         is calculated. d should be a subset of the input labels. If None, all
         the input variables are considered. Default is None.
     :param str method: the method used to calculate the Laplacian, defaults
         to 'std'.
+    :raises ValueError: for a vector field, the derivative with respect to
+        all coordinates must be performed.
     :raises NotImplementedError: 'divgrad' is not implemented as a method.
     :return: The tensor containing the result of the Laplacian operator.
     :rtype: LabelTensor
     """
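For orientation, a minimal usage sketch of this operator for the scalar case. The import paths and the LabelTensor.extract API do not appear in this hunk, so treat them as unverified assumptions rather than the library's documented interface:

    import torch
    from pina import LabelTensor            # assumed import path
    from pina.operators import laplacian    # assumed import path

    # Ten random points in R^2; the input must track gradients.
    coords = torch.rand((10, 2), requires_grad=True)
    input_ = LabelTensor(coords, labels=["x", "y"])

    # u(x, y) = x^2 + y^2, built from the labelled input so the autograd
    # graph connects output_ back to input_. Its Laplacian is 4 everywhere.
    u = input_.extract(["x"]) ** 2 + input_.extract(["y"]) ** 2
    output_ = LabelTensor(u, labels=["u"])

    lap_u = laplacian(output_, input_, components=["u"], d=["x", "y"])
    print(lap_u.labels)  # ['ddu'], per the labelling scheme in this diff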
+
+    def scalar_laplace(output_, input_, components, d):
+        """
+        Compute the Laplace operator for a scalar output.
+
+        :param LabelTensor output_: the output tensor on which the Laplacian
+            is computed. It has to be a column tensor.
+        :param LabelTensor input_: the input tensor with respect to which
+            the Laplacian is computed.
+        :param list(str) components: the names of the output variables for
+            which to calculate the Laplacian. They should be a subset of the
+            output labels. If None, all the output variables are considered.
+        :param list(str) d: the names of the input variables on which the
+            Laplacian is computed. d should be a subset of the input labels.
+            If None, all the input variables are considered. Default is None.
+
+        :return: The tensor containing the result of the Laplacian operator.
+        :rtype: LabelTensor
+        """
+        grad_output = grad(output_, input_, components=components, d=d)
+        result = torch.zeros(output_.shape[0], 1, device=output_.device)
+
+        # Accumulate the diagonal of the Hessian one input variable at a
+        # time; super(torch.Tensor, ...) indexes the underlying tensor
+        # directly, bypassing the LabelTensor __getitem__ override.
+        for i, label in enumerate(grad_output.labels):
+            gg = grad(grad_output, input_, d=d, components=[label])
+            result[:, 0] += super(torch.Tensor, gg.T).__getitem__(i)
+
+        return result
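The helper above is the core of the fix: a scalar Laplacian is the trace of the Hessian, accumulated one coordinate at a time by differentiating twice. The same computation in plain torch.autograd, with no PINA types involved, as a self-contained sanity check:

    import torch

    # Five points in R^2; u(x, y) = x^2 + y^2, whose Laplacian is 4 everywhere.
    x = torch.rand(5, 2, requires_grad=True)
    u = (x ** 2).sum(dim=1, keepdim=True)

    # First derivatives du/dx_i, keeping the graph for a second pass.
    (grad_u,) = torch.autograd.grad(
        u, x, grad_outputs=torch.ones_like(u), create_graph=True
    )

    lap = torch.zeros(x.shape[0], 1)
    for i in range(x.shape[1]):
        # Differentiate du/dx_i again; keep only the diagonal d^2u/dx_i^2.
        (gg,) = torch.autograd.grad(
            grad_u[:, i], x,
            grad_outputs=torch.ones_like(grad_u[:, i]),
            create_graph=True,  # keeps the graph alive across iterations
        )
        lap[:, 0] += gg[:, i]

    print(lap)  # every entry is 4.0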
 
     if d is None:
         d = input_.labels
 
     if components is None:
         components = output_.labels
 
+    if len(components) != len(d) and len(components) != 1:
+        raise ValueError
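The new check encodes the rule from the docstring: a call is valid either for a scalar field (a single component, any number of coordinates) or for a vector field with exactly one component per coordinate. A hypothetical rejected call, with invented labels:

    # Two components against three coordinates: neither a scalar field nor
    # one component per coordinate, so this raises ValueError.
    # laplacian(output_, input_, components=["ux", "uy"], d=["x", "y", "z"])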
 
     if method == "divgrad":
         raise NotImplementedError("divgrad not implemented as method")
         # TODO fix
         # grad_output = grad(output_, input_, components, d)
         # result = div(grad_output, input_, d=d)
     elif method == "std":
         if len(components) == 1:
-            grad_output = grad(output_, input_, components=components, d=d)
-            result = torch.zeros(output_.shape[0], 1, device=output_.device)
-            for i, label in enumerate(grad_output.labels):
-                gg = grad(grad_output, input_, d=d, components=[label])
-                result[:, 0] += super(torch.Tensor, gg.T).__getitem__(
-                    i
-                )  # TODO improve
+            result = scalar_laplace(output_, input_, components, d)
             labels = [f"dd{components[0]}"]
 
         else:
             result = torch.empty(
-                input_.shape[0], len(components), device=output_.device
+                size=(input_.shape[0], len(components)),
+                dtype=output_.dtype, device=output_.device
             )
             labels = [None] * len(components)
-            for idx, (ci, di) in enumerate(zip(components, d)):
-                if not isinstance(ci, list):
-                    ci = [ci]
-                if not isinstance(di, list):
-                    di = [di]
-                grad_output = grad(output_, input_, components=ci, d=di)
-                result[:, idx] = grad(grad_output, input_, d=di).flatten()
-                labels[idx] = f"dd{ci}dd{di}"
+            for idx, c in enumerate(components):
+                result[:, idx] = scalar_laplace(output_, input_, c, d).flatten()
+                labels[idx] = f"dd{c}"
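This loop is what the PR title refers to: the removed version paired component i only with coordinate i, so it computed d^2(u_i)/d(x_i)^2 rather than the full Laplacian of each component; the replacement applies the complete scalar Laplacian to every component. A hedged sketch of the vector case, under the same assumed imports as the scalar example above:

    import torch
    from pina import LabelTensor            # assumed import path
    from pina.operators import laplacian    # assumed import path

    coords = torch.rand((10, 2), requires_grad=True)
    input_ = LabelTensor(coords, labels=["x", "y"])

    # Vector field (ux, uy) = (x^2 + y^2, x * y).
    x, y = input_.extract(["x"]), input_.extract(["y"])
    output_ = LabelTensor(torch.cat([x ** 2 + y ** 2, x * y], dim=1),
                          labels=["ux", "uy"])

    lap = laplacian(output_, input_, components=["ux", "uy"], d=["x", "y"])
    print(lap.labels)  # ['ddux', 'dduy']
    # Column 'ddux' should be 4 everywhere and 'dduy' 0 everywhere, since
    # lap(x^2 + y^2) = 4 and lap(x * y) = 0.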
 
     result = result.as_subclass(LabelTensor)
     result.labels = labels