in progress
This commit is contained in:
@@ -1,119 +1,191 @@
|
||||
import torch
|
||||
import pytest
|
||||
|
||||
from pina import LabelTensor
|
||||
from pina.label_tensor import LabelTensor
|
||||
#import pina
|
||||
|
||||
data = torch.rand((20, 3))
|
||||
labels = ['a', 'b', 'c']
|
||||
labels_column = {
|
||||
1: {
|
||||
"name": "space",
|
||||
"dof": ['x', 'y', 'z']
|
||||
}
|
||||
}
|
||||
labels_row = {
|
||||
0: {
|
||||
"name": "samples",
|
||||
"dof": range(20)
|
||||
}
|
||||
}
|
||||
labels_all = labels_column | labels_row
|
||||
|
||||
|
||||
def test_constructor():
|
||||
@pytest.mark.parametrize("labels", [labels_column, labels_row, labels_all])
|
||||
def test_constructor(labels):
|
||||
LabelTensor(data, labels)
|
||||
|
||||
|
||||
def test_wrong_constructor():
|
||||
with pytest.raises(ValueError):
|
||||
LabelTensor(data, ['a', 'b'])
|
||||
|
||||
|
||||
def test_labels():
|
||||
@pytest.mark.parametrize("labels", [labels_column, labels_all])
|
||||
@pytest.mark.parametrize("labels_te", ['z', ['z'], {'space': ['z']}])
|
||||
def test_extract_column(labels, labels_te):
|
||||
tensor = LabelTensor(data, labels)
|
||||
assert isinstance(tensor, torch.Tensor)
|
||||
assert tensor.labels == labels
|
||||
with pytest.raises(ValueError):
|
||||
tensor.labels = labels[:-1]
|
||||
new = tensor.extract(labels_te)
|
||||
assert new.ndim == tensor.ndim
|
||||
assert new.shape[1] == 1
|
||||
assert new.shape[0] == 20
|
||||
assert torch.all(torch.isclose(data[:, 2].reshape(-1, 1), new))
|
||||
|
||||
|
||||
def test_extract():
|
||||
label_to_extract = ['a', 'c']
|
||||
@pytest.mark.parametrize("labels", [labels_row, labels_all])
|
||||
@pytest.mark.parametrize("labels_te", [2, [2], {'samples': [2]}])
|
||||
def test_extract_row(labels, labels_te):
|
||||
tensor = LabelTensor(data, labels)
|
||||
new = tensor.extract(label_to_extract)
|
||||
assert new.labels == label_to_extract
|
||||
assert new.shape[1] == len(label_to_extract)
|
||||
assert torch.all(torch.isclose(data[:, 0::2], new))
|
||||
new = tensor.extract(labels_te)
|
||||
assert new.ndim == tensor.ndim
|
||||
assert new.shape[1] == 3
|
||||
assert new.shape[0] == 1
|
||||
assert torch.all(torch.isclose(data[2].reshape(1, -1), new))
|
||||
|
||||
|
||||
def test_extract_onelabel():
|
||||
label_to_extract = ['a']
|
||||
@pytest.mark.parametrize("labels_te", [
|
||||
{'samples': [2], 'space': ['z']},
|
||||
{'space': 'z', 'samples': 2}
|
||||
])
|
||||
def test_extract_2D(labels_te):
|
||||
labels = labels_all
|
||||
tensor = LabelTensor(data, labels)
|
||||
new = tensor.extract(label_to_extract)
|
||||
assert new.ndim == 2
|
||||
assert new.labels == label_to_extract
|
||||
assert new.shape[1] == len(label_to_extract)
|
||||
assert torch.all(torch.isclose(data[:, 0].reshape(-1, 1), new))
|
||||
new = tensor.extract(labels_te)
|
||||
assert new.ndim == tensor.ndim
|
||||
assert new.shape[1] == 1
|
||||
assert new.shape[0] == 1
|
||||
assert torch.all(torch.isclose(data[2,2].reshape(1, 1), new))
|
||||
|
||||
|
||||
def test_wrong_extract():
|
||||
label_to_extract = ['a', 'cc']
|
||||
def test_extract_3D():
|
||||
labels = labels_all
|
||||
data = torch.rand((20, 3, 4))
|
||||
labels = {
|
||||
1: {
|
||||
"name": "space",
|
||||
"dof": ['x', 'y', 'z']
|
||||
},
|
||||
2: {
|
||||
"name": "time",
|
||||
"dof": range(4)
|
||||
},
|
||||
}
|
||||
labels_te = {
|
||||
'space': ['x', 'z'],
|
||||
'time': range(1, 4)
|
||||
}
|
||||
tensor = LabelTensor(data, labels)
|
||||
with pytest.raises(ValueError):
|
||||
tensor.extract(label_to_extract)
|
||||
new = tensor.extract(labels_te)
|
||||
assert new.ndim == tensor.ndim
|
||||
assert new.shape[0] == 20
|
||||
assert new.shape[1] == 2
|
||||
assert new.shape[2] == 3
|
||||
assert torch.all(torch.isclose(
|
||||
data[:, 0::2, 1:4].reshape(20, 2, 3),
|
||||
new
|
||||
))
|
||||
|
||||
# def test_labels():
|
||||
# tensor = LabelTensor(data, labels)
|
||||
# assert isinstance(tensor, torch.Tensor)
|
||||
# assert tensor.labels == labels
|
||||
# with pytest.raises(ValueError):
|
||||
# tensor.labels = labels[:-1]
|
||||
|
||||
|
||||
def test_extract_order():
|
||||
label_to_extract = ['c', 'a']
|
||||
tensor = LabelTensor(data, labels)
|
||||
new = tensor.extract(label_to_extract)
|
||||
expected = torch.cat(
|
||||
(data[:, 2].reshape(-1, 1), data[:, 0].reshape(-1, 1)),
|
||||
dim=1)
|
||||
assert new.labels == label_to_extract
|
||||
assert new.shape[1] == len(label_to_extract)
|
||||
assert torch.all(torch.isclose(expected, new))
|
||||
# def test_extract():
|
||||
# label_to_extract = ['a', 'c']
|
||||
# tensor = LabelTensor(data, labels)
|
||||
# new = tensor.extract(label_to_extract)
|
||||
# assert new.labels == label_to_extract
|
||||
# assert new.shape[1] == len(label_to_extract)
|
||||
# assert torch.all(torch.isclose(data[:, 0::2], new))
|
||||
|
||||
|
||||
def test_merge():
|
||||
tensor = LabelTensor(data, labels)
|
||||
tensor_a = tensor.extract('a')
|
||||
tensor_b = tensor.extract('b')
|
||||
tensor_c = tensor.extract('c')
|
||||
|
||||
tensor_bc = tensor_b.append(tensor_c)
|
||||
assert torch.allclose(tensor_bc, tensor.extract(['b', 'c']))
|
||||
# def test_extract_onelabel():
|
||||
# label_to_extract = ['a']
|
||||
# tensor = LabelTensor(data, labels)
|
||||
# new = tensor.extract(label_to_extract)
|
||||
# assert new.ndim == 2
|
||||
# assert new.labels == label_to_extract
|
||||
# assert new.shape[1] == len(label_to_extract)
|
||||
# assert torch.all(torch.isclose(data[:, 0].reshape(-1, 1), new))
|
||||
|
||||
|
||||
def test_merge2():
|
||||
tensor = LabelTensor(data, labels)
|
||||
tensor_b = tensor.extract('b')
|
||||
tensor_c = tensor.extract('c')
|
||||
|
||||
tensor_bc = tensor_b.append(tensor_c)
|
||||
assert torch.allclose(tensor_bc, tensor.extract(['b', 'c']))
|
||||
# def test_wrong_extract():
|
||||
# label_to_extract = ['a', 'cc']
|
||||
# tensor = LabelTensor(data, labels)
|
||||
# with pytest.raises(ValueError):
|
||||
# tensor.extract(label_to_extract)
|
||||
|
||||
|
||||
def test_getitem():
|
||||
tensor = LabelTensor(data, labels)
|
||||
tensor_view = tensor['a']
|
||||
|
||||
assert tensor_view.labels == ['a']
|
||||
assert torch.allclose(tensor_view.flatten(), data[:, 0])
|
||||
|
||||
tensor_view = tensor['a', 'c']
|
||||
|
||||
assert tensor_view.labels == ['a', 'c']
|
||||
assert torch.allclose(tensor_view, data[:, 0::2])
|
||||
|
||||
def test_getitem2():
|
||||
tensor = LabelTensor(data, labels)
|
||||
tensor_view = tensor[:5]
|
||||
assert tensor_view.labels == labels
|
||||
assert torch.allclose(tensor_view, data[:5])
|
||||
|
||||
idx = torch.randperm(tensor.shape[0])
|
||||
tensor_view = tensor[idx]
|
||||
assert tensor_view.labels == labels
|
||||
# def test_extract_order():
|
||||
# label_to_extract = ['c', 'a']
|
||||
# tensor = LabelTensor(data, labels)
|
||||
# new = tensor.extract(label_to_extract)
|
||||
# expected = torch.cat(
|
||||
# (data[:, 2].reshape(-1, 1), data[:, 0].reshape(-1, 1)),
|
||||
# dim=1)
|
||||
# assert new.labels == label_to_extract
|
||||
# assert new.shape[1] == len(label_to_extract)
|
||||
# assert torch.all(torch.isclose(expected, new))
|
||||
|
||||
|
||||
def test_slice():
|
||||
tensor = LabelTensor(data, labels)
|
||||
tensor_view = tensor[:5, :2]
|
||||
assert tensor_view.labels == labels[:2]
|
||||
assert torch.allclose(tensor_view, data[:5, :2])
|
||||
# def test_merge():
|
||||
# tensor = LabelTensor(data, labels)
|
||||
# tensor_a = tensor.extract('a')
|
||||
# tensor_b = tensor.extract('b')
|
||||
# tensor_c = tensor.extract('c')
|
||||
|
||||
tensor_view2 = tensor[3]
|
||||
assert tensor_view2.labels == labels
|
||||
assert torch.allclose(tensor_view2, data[3])
|
||||
# tensor_bc = tensor_b.append(tensor_c)
|
||||
# assert torch.allclose(tensor_bc, tensor.extract(['b', 'c']))
|
||||
|
||||
tensor_view3 = tensor[:, 2]
|
||||
assert tensor_view3.labels == labels[2]
|
||||
assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1))
|
||||
|
||||
# def test_merge2():
|
||||
# tensor = LabelTensor(data, labels)
|
||||
# tensor_b = tensor.extract('b')
|
||||
# tensor_c = tensor.extract('c')
|
||||
|
||||
# tensor_bc = tensor_b.append(tensor_c)
|
||||
# assert torch.allclose(tensor_bc, tensor.extract(['b', 'c']))
|
||||
|
||||
|
||||
# def test_getitem():
|
||||
# tensor = LabelTensor(data, labels)
|
||||
# tensor_view = tensor['a']
|
||||
|
||||
# assert tensor_view.labels == ['a']
|
||||
# assert torch.allclose(tensor_view.flatten(), data[:, 0])
|
||||
|
||||
# tensor_view = tensor['a', 'c']
|
||||
|
||||
# assert tensor_view.labels == ['a', 'c']
|
||||
# assert torch.allclose(tensor_view, data[:, 0::2])
|
||||
|
||||
# def test_getitem2():
|
||||
# tensor = LabelTensor(data, labels)
|
||||
# tensor_view = tensor[:5]
|
||||
# assert tensor_view.labels == labels
|
||||
# assert torch.allclose(tensor_view, data[:5])
|
||||
|
||||
# idx = torch.randperm(tensor.shape[0])
|
||||
# tensor_view = tensor[idx]
|
||||
# assert tensor_view.labels == labels
|
||||
|
||||
|
||||
# def test_slice():
|
||||
# tensor = LabelTensor(data, labels)
|
||||
# tensor_view = tensor[:5, :2]
|
||||
# assert tensor_view.labels == labels[:2]
|
||||
# assert torch.allclose(tensor_view, data[:5, :2])
|
||||
|
||||
# tensor_view2 = tensor[3]
|
||||
# assert tensor_view2.labels == labels
|
||||
# assert torch.allclose(tensor_view2, data[3])
|
||||
|
||||
# tensor_view3 = tensor[:, 2]
|
||||
# assert tensor_view3.labels == labels[2]
|
||||
# assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1))
|
||||
|
||||
Reference in New Issue
Block a user