minor fix

This commit is contained in:
Your Name
2022-07-20 17:23:53 +02:00
committed by Nicola Demo
parent 75a81af99c
commit a05adea4e3
10 changed files with 231 additions and 203 deletions

View File

@@ -83,14 +83,19 @@ def div(output_, input_, components=None, d=None):
raise ValueError
grad_output = grad(output_, input_, components, d)
div = torch.empty(input_.shape[0], len(components))
div = torch.zeros(input_.shape[0], 1)
# print(grad_output)
# print('empty', div)
labels = [None] * len(components)
for i, c in enumerate(components):
c_fields = [f'd{c}d{di}' for di in d]
div[:, i] = grad_output.extract(c_fields).sum(axis=1)
labels[i] = '+'.join(c_fields)
for i, (c, d) in enumerate(zip(components, d)):
c_fields = f'd{c}d{d}'
# print(c_fields)
div[:, 0] += grad_output.extract(c_fields).sum(axis=1)
labels[i] = c_fields
# print('full', div)
# print(labels)
return LabelTensor(div, labels)
return LabelTensor(div, ['+'.join(labels)])
def nabla(output_, input_, components=None, d=None, method='std'):