53 lines
1.7 KiB
Python
53 lines
1.7 KiB
Python
import torch
|
|
|
|
from pina.utils import merge_tensors
|
|
from pina.label_tensor import LabelTensor
|
|
from pina import LabelTensor
|
|
from pina.domain import EllipsoidDomain, CartesianDomain
|
|
from pina.utils import check_consistency
|
|
import pytest
|
|
from pina.domain import DomainInterface
|
|
|
|
|
|
def test_merge_tensors():
    """Check labels, shape and values of the output of ``merge_tensors``.

    Three labelled tensors (random, all-zeros and all-ones, with 20, 20 and
    30 rows and 3 columns each) are merged. The test expects:

    * the merged labels to be the input labels concatenated in input order;
    * the merged shape to have ``20 * 20 * 30`` rows and 9 columns;
    * the columns coming from the zeros/ones tensors to keep their values.
    """
    random_labels = ["a", "b", "c"]
    zeros_labels = ["d", "e", "f"]
    ones_labels = ["g", "h", "i"]

    random_block = LabelTensor(torch.rand((20, 3)), random_labels)
    zeros_block = LabelTensor(torch.zeros((20, 3)), zeros_labels)
    ones_block = LabelTensor(torch.ones((30, 3)), ones_labels)

    merged = merge_tensors((random_block, zeros_block, ones_block))

    # Labels must appear in the same order the tensors were passed in.
    expected_labels = (*random_labels, *zeros_labels, *ones_labels)
    assert tuple(merged.labels) == expected_labels

    # Row count is the product of the input row counts; columns are summed.
    expected_shape = (20 * 20 * 30, 9)
    assert merged.shape == expected_shape

    # Constant-valued inputs must remain constant after merging.
    zeros_columns = merged.extract(("d", "e", "f"))
    ones_columns = merged.extract(("g", "h", "i"))
    assert torch.all(zeros_columns == 0)
    assert torch.all(ones_columns == 1)
|
|
|
|
|
|
def test_check_consistency_correct():
    """Check that ``check_consistency`` accepts compatible object/type pairs.

    The test passes when none of the calls raises. The pairs exercised are
    a ``LabelTensor`` instance against ``torch.Tensor``, the
    ``CartesianDomain`` class against ``DomainInterface`` (with
    ``subclass=True``), and an ``EllipsoidDomain`` instance against
    ``DomainInterface``.
    """
    domain = EllipsoidDomain({"x": [1, 2], "y": [-2, 1]})
    points = LabelTensor(torch.tensor([[0, 0, 0]]), ["x", "y", "z"])

    # (positional arguments, keyword arguments) for each accepted call.
    accepted_cases = (
        ((points, torch.Tensor), {}),
        ((CartesianDomain, DomainInterface), {"subclass": True}),
        ((domain, DomainInterface), {}),
    )

    for positional, keyword in accepted_cases:
        check_consistency(*positional, **keyword)
|
|
|
|
|
|
def test_check_consistency_incorrect():
    """Check that ``check_consistency`` rejects incompatible pairs.

    Each of the following is expected to raise ``ValueError``: a
    ``LabelTensor`` instance against ``DomainInterface``, the
    ``torch.Tensor`` class against ``DomainInterface`` (with
    ``subclass=True``), and an ``EllipsoidDomain`` instance against
    ``torch.Tensor``.
    """
    domain = EllipsoidDomain({"x": [1, 2], "y": [-2, 1]})
    points = LabelTensor(torch.tensor([[0, 0, 0]]), ["x", "y", "z"])

    # (positional arguments, keyword arguments) for each rejected call.
    rejected_cases = (
        ((points, DomainInterface), {}),
        ((torch.Tensor, DomainInterface), {"subclass": True}),
        ((domain, torch.Tensor), {}),
    )

    for positional, keyword in rejected_cases:
        # Every case gets its own context so each one must raise by itself.
        with pytest.raises(ValueError):
            check_consistency(*positional, **keyword)
|