Update of LabelTensor class and fix Simplex domain (#362)
*Implement new methods in LabelTensor and fix operators
This commit is contained in:
committed by
Nicola Demo
parent
fdb8f65143
commit
7528f6ef74
@@ -27,15 +27,15 @@ def test_grad_scalar_output():
|
||||
grad_tensor_s = grad(tensor_s, inp)
|
||||
true_val = 2*inp
|
||||
assert grad_tensor_s.shape == inp.shape
|
||||
assert grad_tensor_s.labels[grad_tensor_s.ndim-1]['dof'] == [
|
||||
f'd{tensor_s.labels[tensor_s.ndim-1]["dof"][0]}d{i}' for i in inp.labels[inp.ndim-1]['dof']
|
||||
assert grad_tensor_s.labels == [
|
||||
f'd{tensor_s.labels[0]}d{i}' for i in inp.labels
|
||||
]
|
||||
assert torch.allclose(grad_tensor_s, true_val)
|
||||
|
||||
grad_tensor_s = grad(tensor_s, inp, d=['x', 'y'])
|
||||
assert grad_tensor_s.shape == (20, 2)
|
||||
assert grad_tensor_s.labels[grad_tensor_s.ndim-1]['dof'] == [
|
||||
f'd{tensor_s.labels[tensor_s.ndim-1]["dof"][0]}d{i}' for i in ['x', 'y']
|
||||
assert grad_tensor_s.labels == [
|
||||
f'd{tensor_s.labels[0]}d{i}' for i in ['x', 'y']
|
||||
]
|
||||
assert torch.allclose(grad_tensor_s, true_val)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user