minor fix
This commit is contained in:
@@ -83,14 +83,19 @@ def div(output_, input_, components=None, d=None):
|
||||
raise ValueError
|
||||
|
||||
grad_output = grad(output_, input_, components, d)
|
||||
div = torch.empty(input_.shape[0], len(components))
|
||||
div = torch.zeros(input_.shape[0], 1)
|
||||
# print(grad_output)
|
||||
# print('empty', div)
|
||||
labels = [None] * len(components)
|
||||
for i, c in enumerate(components):
|
||||
c_fields = [f'd{c}d{di}' for di in d]
|
||||
div[:, i] = grad_output.extract(c_fields).sum(axis=1)
|
||||
labels[i] = '+'.join(c_fields)
|
||||
for i, (c, d) in enumerate(zip(components, d)):
|
||||
c_fields = f'd{c}d{d}'
|
||||
# print(c_fields)
|
||||
div[:, 0] += grad_output.extract(c_fields).sum(axis=1)
|
||||
labels[i] = c_fields
|
||||
# print('full', div)
|
||||
# print(labels)
|
||||
|
||||
return LabelTensor(div, labels)
|
||||
return LabelTensor(div, ['+'.join(labels)])
|
||||
|
||||
|
||||
def nabla(output_, input_, components=None, d=None, method='std'):
|
||||
|
||||
Reference in New Issue
Block a user