equation class, fix minor bugs, diff domain (#89)
* equation class * difference domain * dummy dataloader * writer class * refactoring and minor fix
This commit is contained in:
@@ -27,6 +27,7 @@ def test_labels():
|
||||
def test_extract():
|
||||
label_to_extract = ['a', 'c']
|
||||
tensor = LabelTensor(data, labels)
|
||||
print(tensor)
|
||||
new = tensor.extract(label_to_extract)
|
||||
assert new.labels == label_to_extract
|
||||
assert new.shape[1] == len(label_to_extract)
|
||||
@@ -79,3 +80,11 @@ def test_merge():
|
||||
|
||||
tensor_bc = tensor_b.append(tensor_c)
|
||||
assert torch.allclose(tensor_bc, tensor.extract(['b', 'c']))
|
||||
|
||||
|
||||
def test_getitem():
|
||||
tensor = LabelTensor(data, labels)
|
||||
tensor_view = tensor[:5]
|
||||
|
||||
assert tensor_view.labels == labels
|
||||
assert torch.allclose(tensor_view, data[:5])
|
||||
Reference in New Issue
Block a user