Add concatenation test for LabelTensor

This commit is contained in:
FilippoOlivo
2024-09-30 12:23:15 +02:00
committed by Nicola Demo
parent 16351f95ae
commit a888141707

View File

@@ -88,3 +88,61 @@ def test_extract_3D():
data[:, 0::2, 1:4].reshape(20, 2, 3), data[:, 0::2, 1:4].reshape(20, 2, 3),
new new
)) ))
def test_concatenation_3D():
data_1 = torch.rand(20, 3, 4)
labels_1 = ['x', 'y', 'z', 'w']
lt1 = LabelTensor(data_1, labels_1)
data_2 = torch.rand(50, 3, 4)
labels_2 = ['x', 'y', 'z', 'w']
lt2 = LabelTensor(data_2, labels_2)
lt_cat = LabelTensor.cat([lt1, lt2])
assert lt_cat.shape == (70, 3, 4)
assert lt_cat.labels[0]['dof'] == range(70)
assert lt_cat.labels[1]['dof'] == range(3)
assert lt_cat.labels[2]['dof'] == ['x', 'y', 'z', 'w']
data_1 = torch.rand(20, 3, 4)
labels_1 = ['x', 'y', 'z', 'w']
lt1 = LabelTensor(data_1, labels_1)
data_2 = torch.rand(20, 2, 4)
labels_2 = ['x', 'y', 'z', 'w']
lt2 = LabelTensor(data_2, labels_2)
lt_cat = LabelTensor.cat([lt1, lt2], dim=1)
assert lt_cat.shape == (20, 5, 4)
assert lt_cat.labels[0]['dof'] == range(20)
assert lt_cat.labels[1]['dof'] == range(5)
assert lt_cat.labels[2]['dof'] == ['x', 'y', 'z', 'w']
data_1 = torch.rand(20, 3, 2)
labels_1 = ['x', 'y']
lt1 = LabelTensor(data_1, labels_1)
data_2 = torch.rand(20, 3, 3)
labels_2 = ['z', 'w', 'a']
lt2 = LabelTensor(data_2, labels_2)
lt_cat = LabelTensor.cat([lt1, lt2], dim=2)
assert lt_cat.shape == (20, 3, 5)
assert lt_cat.labels[2]['dof'] == ['x', 'y', 'z', 'w', 'a']
assert lt_cat.labels[0]['dof'] == range(20)
assert lt_cat.labels[1]['dof'] == range(3)
data_1 = torch.rand(20, 2, 4)
labels_1 = ['x', 'y', 'z', 'w']
lt1 = LabelTensor(data_1, labels_1)
data_2 = torch.rand(20, 3, 4)
labels_2 = ['x', 'y', 'z', 'w']
lt2 = LabelTensor(data_2, labels_2)
with pytest.raises(ValueError):
LabelTensor.cat([lt1, lt2], dim=2)
data_1 = torch.rand(20, 3, 2)
labels_1 = ['x', 'y']
lt1 = LabelTensor(data_1, labels_1)
data_2 = torch.rand(20, 3, 3)
labels_2 = ['x', 'w', 'a']
lt2 = LabelTensor(data_2, labels_2)
lt_cat = LabelTensor.cat([lt1, lt2], dim=2)
assert lt_cat.shape == (20, 3, 5)
assert lt_cat.labels[2]['dof'] == range(5)
assert lt_cat.labels[0]['dof'] == range(20)
assert lt_cat.labels[1]['dof'] == range(3)