fix slicing for LabelTensor (#167)
* fix slicing for LabelTensor * Update testing_pr.yml for solving python3.1 error
This commit is contained in:
@@ -73,7 +73,7 @@ def test_merge():
|
||||
tensor_bc = tensor_b.append(tensor_c)
|
||||
assert torch.allclose(tensor_bc, tensor.extract(['b', 'c']))
|
||||
|
||||
def test_merge():
|
||||
def test_merge2():
|
||||
tensor = LabelTensor(data, labels)
|
||||
tensor_b = tensor.extract('b')
|
||||
tensor_c = tensor.extract('c')
|
||||
@@ -87,4 +87,18 @@ def test_getitem():
|
||||
tensor_view = tensor[:5]
|
||||
|
||||
assert tensor_view.labels == labels
|
||||
assert torch.allclose(tensor_view, data[:5])
|
||||
assert torch.allclose(tensor_view, data[:5])
|
||||
|
||||
def test_slice():
|
||||
tensor = LabelTensor(data, labels)
|
||||
tensor_view = tensor[:5, :2]
|
||||
assert tensor_view.labels == labels[:2]
|
||||
assert torch.allclose(tensor_view, data[:5, :2])
|
||||
|
||||
tensor_view2 = tensor[3]
|
||||
assert tensor_view2.labels == labels
|
||||
assert torch.allclose(tensor_view2, data[3])
|
||||
|
||||
tensor_view3 = tensor[:, 2]
|
||||
assert tensor_view3.labels == labels[2]
|
||||
assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1))
|
||||
Reference in New Issue
Block a user