use LabelTensor, fix minor, docs

This commit is contained in:
Your Name
2022-03-29 18:05:26 +02:00
parent 12f4084d7f
commit 6b001c6c53
19 changed files with 370 additions and 322 deletions

View File

@@ -59,3 +59,22 @@ def test_extract_order():
assert new.labels == label_to_extract
assert new.shape[1] == len(label_to_extract)
assert torch.all(torch.isclose(expected, new))
def test_merge():
tensor = LabelTensor(data, labels)
tensor_a = tensor.extract('a')
tensor_b = tensor.extract('b')
tensor_c = tensor.extract('c')
tensor_bc = tensor_b.append(tensor_c)
assert torch.allclose(tensor_bc, tensor.extract(['b', 'c']))
def test_merge():
tensor = LabelTensor(data, labels)
tensor_a = tensor.extract('a')
tensor_b = tensor.extract('b')
tensor_c = tensor.extract('c')
tensor_bb = tensor_b.append(tensor_b)
assert torch.allclose(tensor_b, tensor.extract(['b', 'c']))