Fix __getitem__ in LabelTensor (#546)
* Fix LabelTensor * Cleaning label_tensor.py --------- Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com>
This commit is contained in:
committed by
FilippoOlivo
parent
e250e3f5f7
commit
a1f98764d3
@@ -278,3 +278,63 @@ def test_cat_bool(labels):
|
||||
assert selected.stored_labels[1]["dof"] == [f"s{i}" for i in range(10)]
|
||||
if isinstance(labels, dict):
|
||||
assert selected.stored_labels[0]["dof"] == ["a", "b"]
|
||||
|
||||
|
||||
def test_getitem_int():
|
||||
data = torch.rand(20, 3)
|
||||
labels = {1: {"name": 1, "dof": ["x", "y", "z"]}}
|
||||
lt = LabelTensor(data, labels)
|
||||
new = lt[0, 0]
|
||||
assert new.ndim == 1
|
||||
assert new.shape[0] == 1
|
||||
assert torch.all(torch.isclose(data[0, 0], new))
|
||||
|
||||
data = torch.rand(20, 3, 2)
|
||||
labels = {
|
||||
1: {"name": 1, "dof": ["x", "y", "z"]},
|
||||
2: {"name": 2, "dof": ["a", "b"]},
|
||||
}
|
||||
lt = LabelTensor(data, labels)
|
||||
new = lt[0, 0, 0]
|
||||
assert new.ndim == 2
|
||||
assert new.shape[0] == 1
|
||||
assert new.shape[1] == 1
|
||||
assert torch.all(torch.isclose(data[0, 0, 0], new))
|
||||
assert new.stored_labels[0]["dof"] == ["x"]
|
||||
assert new.stored_labels[1]["dof"] == ["a"]
|
||||
|
||||
new = lt[0, 0, :]
|
||||
assert new.ndim == 2
|
||||
assert new.shape[0] == 1
|
||||
assert new.shape[1] == 2
|
||||
assert torch.all(torch.isclose(data[0, 0, :], new))
|
||||
assert new.stored_labels[0]["dof"] == ["x"]
|
||||
assert new.stored_labels[1]["dof"] == ["a", "b"]
|
||||
|
||||
new = lt[0, :, 1]
|
||||
assert new.ndim == 2
|
||||
assert new.shape[0] == 3
|
||||
assert new.shape[1] == 1
|
||||
assert torch.all(torch.isclose(data[0, :, 1], new.squeeze()))
|
||||
assert new.stored_labels[0]["dof"] == ["x", "y", "z"]
|
||||
assert new.stored_labels[1]["dof"] == ["b"]
|
||||
|
||||
labels.pop(2)
|
||||
lt = LabelTensor(data, labels)
|
||||
new = lt[0, 0, 0]
|
||||
assert new.ndim == 1
|
||||
assert new.shape[0] == 1
|
||||
assert new.stored_labels[0]["dof"] == ["x"]
|
||||
|
||||
new = lt[:, 0, 0]
|
||||
assert new.ndim == 2
|
||||
assert new.shape[0] == 20
|
||||
assert new.shape[1] == 1
|
||||
assert new.stored_labels[1]["dof"] == ["x"]
|
||||
|
||||
new = lt[:, 0, :]
|
||||
assert new.ndim == 3
|
||||
assert new.shape[0] == 20
|
||||
assert new.shape[1] == 1
|
||||
assert new.shape[2] == 2
|
||||
assert new.stored_labels[1]["dof"] == ["x"]
|
||||
|
||||
Reference in New Issue
Block a user