Update of LabelTensor class and fix Simplex domain (#362)

*Implement new methods in LabelTensor and fix operators
This commit is contained in:
Filippo Olivo
2024-10-10 18:26:52 +02:00
committed by Nicola Demo
parent fdb8f65143
commit 7528f6ef74
19 changed files with 551 additions and 217 deletions

View File

@@ -27,15 +27,15 @@ def test_grad_scalar_output():
grad_tensor_s = grad(tensor_s, inp)
true_val = 2*inp
assert grad_tensor_s.shape == inp.shape
assert grad_tensor_s.labels[grad_tensor_s.ndim-1]['dof'] == [
f'd{tensor_s.labels[tensor_s.ndim-1]["dof"][0]}d{i}' for i in inp.labels[inp.ndim-1]['dof']
assert grad_tensor_s.labels == [
f'd{tensor_s.labels[0]}d{i}' for i in inp.labels
]
assert torch.allclose(grad_tensor_s, true_val)
grad_tensor_s = grad(tensor_s, inp, d=['x', 'y'])
assert grad_tensor_s.shape == (20, 2)
assert grad_tensor_s.labels[grad_tensor_s.ndim-1]['dof'] == [
f'd{tensor_s.labels[tensor_s.ndim-1]["dof"][0]}d{i}' for i in ['x', 'y']
assert grad_tensor_s.labels == [
f'd{tensor_s.labels[0]}d{i}' for i in ['x', 'y']
]
assert torch.allclose(grad_tensor_s, true_val)