fix laplacian
This commit is contained in:
@@ -209,49 +209,21 @@ def laplacian(output_, input_, components=None, d=None, method="std"):
|
||||
|
||||
elif method == "std":
|
||||
if len(components) == 1:
|
||||
# result = scalar_laplace(output_, input_, components, d) # TODO check (from 0.1)
|
||||
grad_output = grad(output_, input_, components=components, d=d)
|
||||
to_append_tensors = []
|
||||
for i, label in enumerate(grad_output.labels):
|
||||
gg = grad(grad_output, input_, d=d, components=[label])
|
||||
gg = gg.extract([gg.labels[i]])
|
||||
|
||||
to_append_tensors.append(gg)
|
||||
result = scalar_laplace(output_, input_, components, d)
|
||||
labels = [f"dd{components[0]}"]
|
||||
result = LabelTensor.summation(tensors=to_append_tensors)
|
||||
result.labels = labels
|
||||
else:
|
||||
# result = torch.empty( # TODO check (from 0.1)
|
||||
# size=(input_.shape[0], len(components)),
|
||||
# dtype=output_.dtype,
|
||||
# device=output_.device,
|
||||
# )
|
||||
# labels = [None] * len(components)
|
||||
# for idx, c in enumerate(components):
|
||||
# result[:, idx] = scalar_laplace(output_, input_, c, d).flatten()
|
||||
# labels[idx] = f"dd{c}"
|
||||
|
||||
# result = result.as_subclass(LabelTensor)
|
||||
# result.labels = labels
|
||||
else:
|
||||
result = torch.empty(
|
||||
input_.shape[0], len(components), device=output_.device
|
||||
)
|
||||
labels = [None] * len(components)
|
||||
to_append_tensors = [None] * len(components)
|
||||
for idx, (ci, di) in enumerate(zip(components, d)):
|
||||
for idx, ci in enumerate(components):
|
||||
result[:, idx] = scalar_laplace(output_, input_, ci, d).flatten()
|
||||
labels[idx] = f"dd{ci}"
|
||||
|
||||
if not isinstance(ci, list):
|
||||
ci = [ci]
|
||||
if not isinstance(di, list):
|
||||
di = [di]
|
||||
|
||||
grad_output = grad(output_, input_, components=ci, d=di)
|
||||
result[:, idx] = grad(grad_output, input_, d=di).flatten()
|
||||
to_append_tensors[idx] = grad(grad_output, input_, d=di)
|
||||
labels[idx] = f"dd{ci[0]}dd{di[0]}"
|
||||
result = LabelTensor.cat(tensors=to_append_tensors,
|
||||
dim=output_.tensor.ndim - 1)
|
||||
result = result.as_subclass(LabelTensor)
|
||||
result.labels = labels
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user