Files
PINA/tests/test_label_tensor.py
Nicola Demo bd88e24174 fix slicing for LabelTensor (#167)
* fix slicing for LabelTensor
* Update testing_pr.yml for solving python3.1 error
2023-11-17 09:51:29 +01:00

104 lines
2.8 KiB
Python

import torch
import pytest
from pina import LabelTensor
# Column names shared by every test; one label per column of `data`.
labels = list('abc')
# Shared fixture: 20 rows of uniform random values, one column per label.
data = torch.rand(20, len(labels))
def test_constructor():
    """Build a LabelTensor from the shared data and labels.

    The test passes as long as construction does not raise.
    """
    values, names = data, labels
    LabelTensor(values, names)
def test_wrong_constructor():
    """Two labels for the three-column data must raise ValueError."""
    too_few_labels = ['a', 'b']
    with pytest.raises(ValueError):
        LabelTensor(data, too_few_labels)
def test_labels():
    """Check the type, the `labels` attribute and the labels setter.

    A LabelTensor must be a torch.Tensor, must report the labels it was
    built with, and must raise ValueError when assigned a label list that
    is one entry short.
    """
    labelled = LabelTensor(data, labels)

    assert isinstance(labelled, torch.Tensor)
    assert labelled.labels == labels

    # Drop the last label so the list no longer has one entry per column.
    truncated = labels[:-1]
    with pytest.raises(ValueError):
        labelled.labels = truncated
def test_extract():
    """Extract columns 'a' and 'c' and compare with data columns 0 and 2."""
    wanted = ['a', 'c']
    source = LabelTensor(data, labels)
    # NOTE(review): looks like a leftover debug print; kept because it also
    # runs the tensor's string conversion -- confirm whether it is needed.
    print(source)

    picked = source.extract(wanted)

    assert picked.labels == wanted
    assert picked.shape[1] == len(wanted)
    # Every second column of the three-column data is exactly columns 0 and 2.
    reference = data[:, ::2]
    assert torch.all(torch.isclose(reference, picked))
def test_extract_onelabel():
    """Extracting a single label must still give a two-dimensional result."""
    wanted = ['a']
    source = LabelTensor(data, labels)

    picked = source.extract(wanted)

    assert picked.ndim == 2
    assert picked.labels == wanted
    assert picked.shape[1] == len(wanted)
    # Slicing (rather than indexing) keeps column 0 as a (rows, 1) tensor.
    first_column = data[:, :1]
    assert torch.all(torch.isclose(first_column, picked))
def test_wrong_extract():
    """Asking for a label that is not among the tensor's labels must fail.

    'cc' is not one of the shared labels, so `extract` is expected to
    raise ValueError.
    """
    source = LabelTensor(data, labels)
    requested = ['a', 'cc']
    with pytest.raises(ValueError):
        source.extract(requested)
def test_extract_order():
    """Extracted columns must follow the requested label order.

    Labels are requested as ['c', 'a'], i.e. reversed with respect to
    their position in the tensor, so the result must be data column 2
    followed by data column 0.
    """
    label_to_extract = ['c', 'a']
    tensor = LabelTensor(data, labels)

    new = tensor.extract(label_to_extract)

    # Column indexing builds the (rows, 2) reference directly; the former
    # debug print of this plain tensor has been removed.
    expected = data[:, [2, 0]]

    assert new.labels == label_to_extract
    assert new.shape[1] == len(label_to_extract)
    assert torch.all(torch.isclose(expected, new))
def test_merge():
    """Appending single-column tensors must match a multi-label extract.

    Each column is extracted on its own and then merged pairwise with
    `append`; every merged result is compared with extracting the same
    two labels in one call.
    """
    tensor = LabelTensor(data, labels)
    tensor_a = tensor.extract('a')
    tensor_b = tensor.extract('b')
    tensor_c = tensor.extract('c')

    tensor_bc = tensor_b.append(tensor_c)
    assert torch.allclose(tensor_bc, tensor.extract(['b', 'c']))

    # `tensor_a` was previously extracted but never used, which made this
    # test a duplicate of `test_merge2`; check the 'a' + 'b' merge as well.
    tensor_ab = tensor_a.append(tensor_b)
    assert torch.allclose(tensor_ab, tensor.extract(['a', 'b']))
def test_merge2():
    """Appending column 'c' to column 'b' must match extracting ['b', 'c']."""
    full = LabelTensor(data, labels)
    left = full.extract('b')
    right = full.extract('c')

    merged = left.append(right)

    reference = full.extract(['b', 'c'])
    assert torch.allclose(merged, reference)
def test_getitem():
    """Slicing the first rows keeps all labels and the matching values."""
    n_rows = 5
    full = LabelTensor(data, labels)

    head = full[:n_rows]

    assert head.labels == labels
    assert torch.allclose(head, data[:n_rows])
def test_slice():
    """Check labels and values for three ways of indexing a LabelTensor."""
    full = LabelTensor(data, labels)

    # Row slice plus column slice: labels are sliced like the columns.
    block = full[:5, :2]
    assert block.labels == labels[:2]
    assert torch.allclose(block, data[:5, :2])

    # Single integer row index: all labels are kept.
    row = full[3]
    assert row.labels == labels
    assert torch.allclose(row, data[3])

    # All rows, single integer column index.
    # NOTE(review): labels[2] is the bare string 'c', not a one-element
    # list -- confirm this is the intended label type for this case.
    column = full[:, 2]
    assert column.labels == labels[2]
    # The values are compared against column 2 as a (rows, 1) tensor.
    assert torch.allclose(column, data[:, 2:3])