Files
PINA/tests/test_utils.py
Filippo Olivo 4177bfbb50 Fix Codacy Warnings (#477)
---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
2025-03-19 17:48:18 +01:00

53 lines
1.7 KiB
Python

import torch
from pina.utils import merge_tensors
from pina.label_tensor import LabelTensor
from pina import LabelTensor
from pina.domain import EllipsoidDomain, CartesianDomain
from pina.utils import check_consistency
import pytest
from pina.domain import DomainInterface
def test_merge_tensors():
    """Check that ``merge_tensors`` combines three labelled tensors.

    The merged tensor is expected to:

    * carry the labels of all inputs, in input order (``a`` .. ``i``);
    * have ``20 * 20 * 30`` rows (the product of the input row counts) and
      9 columns (the sum of the input column counts);
    * keep the values of the all-zeros and all-ones inputs in their columns.
    """
    first = LabelTensor(torch.rand((20, 3)), ["a", "b", "c"])
    second = LabelTensor(torch.zeros((20, 3)), ["d", "e", "f"])
    third = LabelTensor(torch.ones((30, 3)), ["g", "h", "i"])

    result = merge_tensors((first, second, third))

    expected_labels = tuple("abcdefghi")
    expected_rows = 20 * 20 * 30
    assert tuple(result.labels) == expected_labels
    assert result.shape == (expected_rows, 9)

    # Columns that came from the constant inputs must keep their constants.
    zero_columns = result.extract(("d", "e", "f"))
    one_columns = result.extract(("g", "h", "i"))
    assert torch.all(zero_columns == 0)
    assert torch.all(one_columns == 1)
def test_check_consistency_correct():
    """Check that ``check_consistency`` accepts matching objects and classes.

    None of the calls below is expected to raise.
    """
    domain = EllipsoidDomain({"x": [1, 2], "y": [-2, 1]})
    points = LabelTensor(torch.tensor([[0, 0, 0]]), ["x", "y", "z"])

    # Instance check: the test expects a LabelTensor to pass as a torch.Tensor.
    check_consistency(points, torch.Tensor)
    # Class check: CartesianDomain is expected to subclass DomainInterface.
    check_consistency(CartesianDomain, DomainInterface, subclass=True)
    # Instance check: an EllipsoidDomain is expected to be a DomainInterface.
    check_consistency(domain, DomainInterface)
def test_check_consistency_incorrect():
    """Check that ``check_consistency`` raises ``ValueError`` on mismatches.

    Each case pairs positional arguments with keyword arguments for one call
    that is expected to be rejected.
    """
    domain = EllipsoidDomain({"x": [1, 2], "y": [-2, 1]})
    points = LabelTensor(torch.tensor([[0, 0, 0]]), ["x", "y", "z"])

    bad_cases = (
        # A tensor instance is not a domain.
        ((points, DomainInterface), {}),
        # torch.Tensor as a class is not a subclass of DomainInterface.
        ((torch.Tensor, DomainInterface), {"subclass": True}),
        # A domain instance is not a tensor.
        ((domain, torch.Tensor), {}),
    )
    for call_args, call_kwargs in bad_cases:
        with pytest.raises(ValueError):
            check_consistency(*call_args, **call_kwargs)