Label Tensor update (#188)
* Update test_label_tensor.py * adding test --------- Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
This commit is contained in:
committed by
Nicola Demo
parent
d556c592e0
commit
cd5bc9a558
@@ -27,7 +27,6 @@ def test_labels():
|
||||
def test_extract():
|
||||
label_to_extract = ['a', 'c']
|
||||
tensor = LabelTensor(data, labels)
|
||||
print(tensor)
|
||||
new = tensor.extract(label_to_extract)
|
||||
assert new.labels == label_to_extract
|
||||
assert new.shape[1] == len(label_to_extract)
|
||||
@@ -58,7 +57,6 @@ def test_extract_order():
|
||||
expected = torch.cat(
|
||||
(data[:, 2].reshape(-1, 1), data[:, 0].reshape(-1, 1)),
|
||||
dim=1)
|
||||
print(expected)
|
||||
assert new.labels == label_to_extract
|
||||
assert new.shape[1] == len(label_to_extract)
|
||||
assert torch.all(torch.isclose(expected, new))
|
||||
@@ -83,6 +81,18 @@ def test_merge2():
|
||||
|
||||
|
||||
def test_getitem():
|
||||
tensor = LabelTensor(data, labels)
|
||||
tensor_view = tensor['a']
|
||||
|
||||
assert tensor_view.labels == ['a']
|
||||
assert torch.allclose(tensor_view.flatten(), data[:, 0])
|
||||
|
||||
tensor_view = tensor['a', 'c']
|
||||
|
||||
assert tensor_view.labels == ['a', 'c']
|
||||
assert torch.allclose(tensor_view, data[:, 0::2])
|
||||
|
||||
def test_getitem2():
|
||||
tensor = LabelTensor(data, labels)
|
||||
tensor_view = tensor[:5]
|
||||
|
||||
@@ -101,4 +111,4 @@ def test_slice():
|
||||
|
||||
tensor_view3 = tensor[:, 2]
|
||||
assert tensor_view3.labels == labels[2]
|
||||
assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1))
|
||||
assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1))
|
||||
|
||||
Reference in New Issue
Block a user