Fix __getitem__ in LabelTensor (#546)

* Fix LabelTensor
* Cleaning label_tensor.py

---------

Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com>
This commit is contained in:
Filippo Olivo
2025-04-14 11:41:59 +02:00
committed by FilippoOlivo
parent e250e3f5f7
commit a1f98764d3
2 changed files with 85 additions and 28 deletions

View File

@@ -278,3 +278,63 @@ def test_cat_bool(labels):
assert selected.stored_labels[1]["dof"] == [f"s{i}" for i in range(10)]
if isinstance(labels, dict):
assert selected.stored_labels[0]["dof"] == ["a", "b"]
def test_getitem_int():
data = torch.rand(20, 3)
labels = {1: {"name": 1, "dof": ["x", "y", "z"]}}
lt = LabelTensor(data, labels)
new = lt[0, 0]
assert new.ndim == 1
assert new.shape[0] == 1
assert torch.all(torch.isclose(data[0, 0], new))
data = torch.rand(20, 3, 2)
labels = {
1: {"name": 1, "dof": ["x", "y", "z"]},
2: {"name": 2, "dof": ["a", "b"]},
}
lt = LabelTensor(data, labels)
new = lt[0, 0, 0]
assert new.ndim == 2
assert new.shape[0] == 1
assert new.shape[1] == 1
assert torch.all(torch.isclose(data[0, 0, 0], new))
assert new.stored_labels[0]["dof"] == ["x"]
assert new.stored_labels[1]["dof"] == ["a"]
new = lt[0, 0, :]
assert new.ndim == 2
assert new.shape[0] == 1
assert new.shape[1] == 2
assert torch.all(torch.isclose(data[0, 0, :], new))
assert new.stored_labels[0]["dof"] == ["x"]
assert new.stored_labels[1]["dof"] == ["a", "b"]
new = lt[0, :, 1]
assert new.ndim == 2
assert new.shape[0] == 3
assert new.shape[1] == 1
assert torch.all(torch.isclose(data[0, :, 1], new.squeeze()))
assert new.stored_labels[0]["dof"] == ["x", "y", "z"]
assert new.stored_labels[1]["dof"] == ["b"]
labels.pop(2)
lt = LabelTensor(data, labels)
new = lt[0, 0, 0]
assert new.ndim == 1
assert new.shape[0] == 1
assert new.stored_labels[0]["dof"] == ["x"]
new = lt[:, 0, 0]
assert new.ndim == 2
assert new.shape[0] == 20
assert new.shape[1] == 1
assert new.stored_labels[1]["dof"] == ["x"]
new = lt[:, 0, :]
assert new.ndim == 3
assert new.shape[0] == 20
assert new.shape[1] == 1
assert new.shape[2] == 2
assert new.stored_labels[1]["dof"] == ["x"]