Implement definition of LabelTensor from list, implement cat method (previously stack) and re-implement extract

This commit is contained in:
FilippoOlivo
2024-09-28 12:23:16 +02:00
committed by Nicola Demo
parent a779007b36
commit c53c3d5b84
2 changed files with 105 additions and 473 deletions

View File

@@ -38,7 +38,7 @@ def test_extract_column(labels, labels_te):
assert torch.all(torch.isclose(data[:, 2].reshape(-1, 1), new))
@pytest.mark.parametrize("labels", [labels_row, labels_all])
@pytest.mark.parametrize("labels_te", [2, [2], {'samples': [2]}])
@pytest.mark.parametrize("labels_te", [{'samples': [2]}])
def test_extract_row(labels, labels_te):
tensor = LabelTensor(data, labels)
new = tensor.extract(labels_te)
@@ -62,7 +62,7 @@ def test_extract_2D(labels_te):
def test_extract_3D():
labels = labels_all
data = torch.rand((20, 3, 4))
data = torch.rand(20, 3, 4)
labels = {
1: {
"name": "space",
@@ -77,6 +77,7 @@ def test_extract_3D():
'space': ['x', 'z'],
'time': range(1, 4)
}
tensor = LabelTensor(data, labels)
new = tensor.extract(labels_te)
assert new.ndim == tensor.ndim
@@ -86,106 +87,4 @@ def test_extract_3D():
assert torch.all(torch.isclose(
data[:, 0::2, 1:4].reshape(20, 2, 3),
new
))
# def test_labels():
# tensor = LabelTensor(data, labels)
# assert isinstance(tensor, torch.Tensor)
# assert tensor.labels == labels
# with pytest.raises(ValueError):
# tensor.labels = labels[:-1]
# def test_extract():
# label_to_extract = ['a', 'c']
# tensor = LabelTensor(data, labels)
# new = tensor.extract(label_to_extract)
# assert new.labels == label_to_extract
# assert new.shape[1] == len(label_to_extract)
# assert torch.all(torch.isclose(data[:, 0::2], new))
# def test_extract_onelabel():
# label_to_extract = ['a']
# tensor = LabelTensor(data, labels)
# new = tensor.extract(label_to_extract)
# assert new.ndim == 2
# assert new.labels == label_to_extract
# assert new.shape[1] == len(label_to_extract)
# assert torch.all(torch.isclose(data[:, 0].reshape(-1, 1), new))
# def test_wrong_extract():
# label_to_extract = ['a', 'cc']
# tensor = LabelTensor(data, labels)
# with pytest.raises(ValueError):
# tensor.extract(label_to_extract)
# def test_extract_order():
# label_to_extract = ['c', 'a']
# tensor = LabelTensor(data, labels)
# new = tensor.extract(label_to_extract)
# expected = torch.cat(
# (data[:, 2].reshape(-1, 1), data[:, 0].reshape(-1, 1)),
# dim=1)
# assert new.labels == label_to_extract
# assert new.shape[1] == len(label_to_extract)
# assert torch.all(torch.isclose(expected, new))
# def test_merge():
# tensor = LabelTensor(data, labels)
# tensor_a = tensor.extract('a')
# tensor_b = tensor.extract('b')
# tensor_c = tensor.extract('c')
# tensor_bc = tensor_b.append(tensor_c)
# assert torch.allclose(tensor_bc, tensor.extract(['b', 'c']))
# def test_merge2():
# tensor = LabelTensor(data, labels)
# tensor_b = tensor.extract('b')
# tensor_c = tensor.extract('c')
# tensor_bc = tensor_b.append(tensor_c)
# assert torch.allclose(tensor_bc, tensor.extract(['b', 'c']))
# def test_getitem():
# tensor = LabelTensor(data, labels)
# tensor_view = tensor['a']
# assert tensor_view.labels == ['a']
# assert torch.allclose(tensor_view.flatten(), data[:, 0])
# tensor_view = tensor['a', 'c']
# assert tensor_view.labels == ['a', 'c']
# assert torch.allclose(tensor_view, data[:, 0::2])
# def test_getitem2():
# tensor = LabelTensor(data, labels)
# tensor_view = tensor[:5]
# assert tensor_view.labels == labels
# assert torch.allclose(tensor_view, data[:5])
# idx = torch.randperm(tensor.shape[0])
# tensor_view = tensor[idx]
# assert tensor_view.labels == labels
# def test_slice():
# tensor = LabelTensor(data, labels)
# tensor_view = tensor[:5, :2]
# assert tensor_view.labels == labels[:2]
# assert torch.allclose(tensor_view, data[:5, :2])
# tensor_view2 = tensor[3]
# assert tensor_view2.labels == labels
# assert torch.allclose(tensor_view2, data[3])
# tensor_view3 = tensor[:, 2]
# assert tensor_view3.labels == labels[2]
# assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1))
))