fix laplacian

This commit is contained in:
giovanni
2025-01-20 13:53:08 +01:00
committed by Nicola Demo
parent a6f0336d06
commit 81830ecc99

View File

@@ -209,49 +209,21 @@ def laplacian(output_, input_, components=None, d=None, method="std"):
elif method == "std":
if len(components) == 1:
# result = scalar_laplace(output_, input_, components, d) # TODO check (from 0.1)
grad_output = grad(output_, input_, components=components, d=d)
to_append_tensors = []
for i, label in enumerate(grad_output.labels):
gg = grad(grad_output, input_, d=d, components=[label])
gg = gg.extract([gg.labels[i]])
to_append_tensors.append(gg)
result = scalar_laplace(output_, input_, components, d)
labels = [f"dd{components[0]}"]
result = LabelTensor.summation(tensors=to_append_tensors)
result.labels = labels
else:
# result = torch.empty( # TODO check (from 0.1)
# size=(input_.shape[0], len(components)),
# dtype=output_.dtype,
# device=output_.device,
# )
# labels = [None] * len(components)
# for idx, c in enumerate(components):
# result[:, idx] = scalar_laplace(output_, input_, c, d).flatten()
# labels[idx] = f"dd{c}"
# result = result.as_subclass(LabelTensor)
# result.labels = labels
else:
result = torch.empty(
input_.shape[0], len(components), device=output_.device
)
labels = [None] * len(components)
to_append_tensors = [None] * len(components)
for idx, (ci, di) in enumerate(zip(components, d)):
for idx, ci in enumerate(components):
result[:, idx] = scalar_laplace(output_, input_, ci, d).flatten()
labels[idx] = f"dd{ci}"
if not isinstance(ci, list):
ci = [ci]
if not isinstance(di, list):
di = [di]
grad_output = grad(output_, input_, components=ci, d=di)
result[:, idx] = grad(grad_output, input_, d=di).flatten()
to_append_tensors[idx] = grad(grad_output, input_, d=di)
labels[idx] = f"dd{ci[0]}dd{di[0]}"
result = LabelTensor.cat(tensors=to_append_tensors,
dim=output_.tensor.ndim - 1)
result = result.as_subclass(LabelTensor)
result.labels = labels
return result