Additional improvement related to #395
This commit is contained in:
committed by
Nicola Demo
parent
afb1bca245
commit
629a6ee43b
@@ -281,3 +281,17 @@ def test_sorting():
|
||||
assert torch.eq(lt_sorted.tensor[:, 1, :], torch.ones(20, 5) * 2).all()
|
||||
assert torch.eq(lt_sorted.tensor[:, 2, :], torch.ones(20, 5) * 3).all()
|
||||
assert torch.eq(lt_sorted.tensor[:, 3, :], torch.ones(20, 5) * 4).all()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("labels",
|
||||
[[f's{i}' for i in range(10)],
|
||||
{0: {'dof': ['a', 'b', 'c']},
|
||||
1: {'dof': [f's{i}' for i in range(10)]}}])
|
||||
def test_cat_bool(labels):
|
||||
out = torch.randn((3, 10))
|
||||
out = LabelTensor(out, labels)
|
||||
selected = out[torch.tensor([True, True, False])]
|
||||
assert selected.shape == (2, 10)
|
||||
assert selected.stored_labels[1]['dof'] == [f's{i}' for i in range(10)]
|
||||
if isinstance(labels, dict):
|
||||
assert selected.stored_labels[0]['dof'] == ['a', 'b']
|
||||
|
||||
Reference in New Issue
Block a user