add dataset and dataloader for sample points (#195)
* add dataset and dataloader for sample points * unittests
This commit is contained in:
@@ -95,10 +95,14 @@ def test_getitem():
|
||||
def test_getitem2():
|
||||
tensor = LabelTensor(data, labels)
|
||||
tensor_view = tensor[:5]
|
||||
|
||||
assert tensor_view.labels == labels
|
||||
assert torch.allclose(tensor_view, data[:5])
|
||||
|
||||
idx = torch.randperm(tensor.shape[0])
|
||||
tensor_view = tensor[idx]
|
||||
assert tensor_view.labels == labels
|
||||
|
||||
|
||||
def test_slice():
|
||||
tensor = LabelTensor(data, labels)
|
||||
tensor_view = tensor[:5, :2]
|
||||
|
||||
Reference in New Issue
Block a user