diff --git a/tests/test_operators.py b/tests/test_operators.py index 35c0791..ccfc17b 100644 --- a/tests/test_operators.py +++ b/tests/test_operators.py @@ -17,9 +17,9 @@ def func_scalar(x): data = torch.rand((20, 3)) -inp = LabelTensor(data, ['x', 'y', 'mu']).requires_grad_(True) +inp = LabelTensor(data, ['x', 'y', 'z']).requires_grad_(True) labels = ['a', 'b', 'c'] -tensor_v = LabelTensor(func_vec(inp), labels) +tensor_v = LabelTensor(func_vector(inp), labels) tensor_s = LabelTensor(func_scalar(inp).reshape(-1, 1), labels[0]) @@ -107,6 +107,8 @@ def test_laplacian_scalar_output(): def test_laplacian_vector_output(): laplace_tensor_v = laplacian(tensor_v, inp) + print(laplace_tensor_v.labels) + print(tensor_v.labels) true_val = 2*torch.ones_like(tensor_v) assert laplace_tensor_v.shape == tensor_v.shape assert laplace_tensor_v.labels == [ @@ -124,3 +126,30 @@ def test_laplacian_vector_output(): f'dd{i}' for i in ['a', 'b'] ] assert torch.allclose(laplace_tensor_v, true_val) + +def test_laplacian_vector_output2(): + x = LabelTensor(torch.linspace(0,1,10, requires_grad=True).reshape(-1,1), labels = ['x']) + y = LabelTensor(torch.linspace(3,4,10, requires_grad=True).reshape(-1,1), labels = ['y']) + input_ = LabelTensor(torch.cat((x,y), dim = 1), labels = ['x', 'y']) + + # Construct two scalar functions: + # u = x**2 + y**2 + # v = x**2 - y**2 + u = LabelTensor(input_.extract('x')**2 + input_.extract('y')**2, labels='u') + v = LabelTensor(input_.extract('x')**2 - input_.extract('y')**2, labels='v') + + # Define a vector-valued function, whose components are u and v. + f = LabelTensor(torch.cat((u,v), dim = 1), labels = ['u', 'v']) + + # Compute the scalar laplacian of both u and v: + # Lap(u) = [4, 4, 4, ..., 4] + # Lap(v) = [0, 0, 0, ..., 0] + lap_u = laplacian(u, input_, components=['u']) + lap_v = laplacian(v, input_, components=['v']) + + # Compute the laplacian of f: the two columns should correspond + # to the laplacians of u and v, respectively... + lap_f = laplacian(f, input_, components=['u', 'v']) + + assert torch.allclose(lap_f.extract('ddu'), lap_u) + assert torch.allclose(lap_f.extract('ddv'), lap_v)