fix tests

This commit is contained in:
Nicola Demo
2025-01-23 09:52:23 +01:00
parent 9aed1a30b3
commit a899327de1
32 changed files with 2331 additions and 2428 deletions

View File

@@ -26,6 +26,7 @@ tensor_s = LabelTensor(func_scalar(inp).reshape(-1, 1), labels[0])
def test_grad_scalar_output():
grad_tensor_s = grad(tensor_s, inp)
true_val = 2*inp
true_val.labels = inp.labels
assert grad_tensor_s.shape == inp.shape
assert grad_tensor_s.labels == [
f'd{tensor_s.labels[0]}d{i}' for i in inp.labels
@@ -37,7 +38,7 @@ def test_grad_scalar_output():
assert grad_tensor_s.labels == [
f'd{tensor_s.labels[0]}d{i}' for i in ['x', 'y']
]
assert torch.allclose(grad_tensor_s, true_val)
assert torch.allclose(grad_tensor_s, true_val.extract(['x', 'y']))
def test_grad_vector_output():