fix tests
This commit is contained in:
@@ -26,6 +26,7 @@ tensor_s = LabelTensor(func_scalar(inp).reshape(-1, 1), labels[0])
|
||||
def test_grad_scalar_output():
|
||||
grad_tensor_s = grad(tensor_s, inp)
|
||||
true_val = 2*inp
|
||||
true_val.labels = inp.labels
|
||||
assert grad_tensor_s.shape == inp.shape
|
||||
assert grad_tensor_s.labels == [
|
||||
f'd{tensor_s.labels[0]}d{i}' for i in inp.labels
|
||||
@@ -37,7 +38,7 @@ def test_grad_scalar_output():
|
||||
assert grad_tensor_s.labels == [
|
||||
f'd{tensor_s.labels[0]}d{i}' for i in ['x', 'y']
|
||||
]
|
||||
assert torch.allclose(grad_tensor_s, true_val)
|
||||
assert torch.allclose(grad_tensor_s, true_val.extract(['x', 'y']))
|
||||
|
||||
|
||||
def test_grad_vector_output():
|
||||
|
||||
Reference in New Issue
Block a user