equation class, fix minor bugs, diff domain (#89)

* equation class
* difference domain
* dummy dataloader
* writer class
* refactoring and minor fix
This commit is contained in:
Nicola Demo
2023-05-15 16:06:01 +02:00
parent be11110bb2
commit 0e3625de80
25 changed files with 691 additions and 246 deletions

View File

@@ -27,6 +27,7 @@ def test_labels():
def test_extract():
label_to_extract = ['a', 'c']
tensor = LabelTensor(data, labels)
print(tensor)
new = tensor.extract(label_to_extract)
assert new.labels == label_to_extract
assert new.shape[1] == len(label_to_extract)
@@ -79,3 +80,11 @@ def test_merge():
tensor_bc = tensor_b.append(tensor_c)
assert torch.allclose(tensor_bc, tensor.extract(['b', 'c']))
def test_getitem():
tensor = LabelTensor(data, labels)
tensor_view = tensor[:5]
assert tensor_view.labels == labels
assert torch.allclose(tensor_view, data[:5])