Add concatenation test for LabelTensor
This commit is contained in:
committed by
Nicola Demo
parent
16351f95ae
commit
a888141707
@@ -88,3 +88,61 @@ def test_extract_3D():
|
|||||||
data[:, 0::2, 1:4].reshape(20, 2, 3),
|
data[:, 0::2, 1:4].reshape(20, 2, 3),
|
||||||
new
|
new
|
||||||
))
|
))
|
||||||
|
|
||||||
|
def test_concatenation_3D():
|
||||||
|
data_1 = torch.rand(20, 3, 4)
|
||||||
|
labels_1 = ['x', 'y', 'z', 'w']
|
||||||
|
lt1 = LabelTensor(data_1, labels_1)
|
||||||
|
data_2 = torch.rand(50, 3, 4)
|
||||||
|
labels_2 = ['x', 'y', 'z', 'w']
|
||||||
|
lt2 = LabelTensor(data_2, labels_2)
|
||||||
|
lt_cat = LabelTensor.cat([lt1, lt2])
|
||||||
|
assert lt_cat.shape == (70, 3, 4)
|
||||||
|
assert lt_cat.labels[0]['dof'] == range(70)
|
||||||
|
assert lt_cat.labels[1]['dof'] == range(3)
|
||||||
|
assert lt_cat.labels[2]['dof'] == ['x', 'y', 'z', 'w']
|
||||||
|
|
||||||
|
data_1 = torch.rand(20, 3, 4)
|
||||||
|
labels_1 = ['x', 'y', 'z', 'w']
|
||||||
|
lt1 = LabelTensor(data_1, labels_1)
|
||||||
|
data_2 = torch.rand(20, 2, 4)
|
||||||
|
labels_2 = ['x', 'y', 'z', 'w']
|
||||||
|
lt2 = LabelTensor(data_2, labels_2)
|
||||||
|
lt_cat = LabelTensor.cat([lt1, lt2], dim=1)
|
||||||
|
assert lt_cat.shape == (20, 5, 4)
|
||||||
|
assert lt_cat.labels[0]['dof'] == range(20)
|
||||||
|
assert lt_cat.labels[1]['dof'] == range(5)
|
||||||
|
assert lt_cat.labels[2]['dof'] == ['x', 'y', 'z', 'w']
|
||||||
|
|
||||||
|
data_1 = torch.rand(20, 3, 2)
|
||||||
|
labels_1 = ['x', 'y']
|
||||||
|
lt1 = LabelTensor(data_1, labels_1)
|
||||||
|
data_2 = torch.rand(20, 3, 3)
|
||||||
|
labels_2 = ['z', 'w', 'a']
|
||||||
|
lt2 = LabelTensor(data_2, labels_2)
|
||||||
|
lt_cat = LabelTensor.cat([lt1, lt2], dim=2)
|
||||||
|
assert lt_cat.shape == (20, 3, 5)
|
||||||
|
assert lt_cat.labels[2]['dof'] == ['x', 'y', 'z', 'w', 'a']
|
||||||
|
assert lt_cat.labels[0]['dof'] == range(20)
|
||||||
|
assert lt_cat.labels[1]['dof'] == range(3)
|
||||||
|
|
||||||
|
data_1 = torch.rand(20, 2, 4)
|
||||||
|
labels_1 = ['x', 'y', 'z', 'w']
|
||||||
|
lt1 = LabelTensor(data_1, labels_1)
|
||||||
|
data_2 = torch.rand(20, 3, 4)
|
||||||
|
labels_2 = ['x', 'y', 'z', 'w']
|
||||||
|
lt2 = LabelTensor(data_2, labels_2)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
LabelTensor.cat([lt1, lt2], dim=2)
|
||||||
|
|
||||||
|
data_1 = torch.rand(20, 3, 2)
|
||||||
|
labels_1 = ['x', 'y']
|
||||||
|
lt1 = LabelTensor(data_1, labels_1)
|
||||||
|
data_2 = torch.rand(20, 3, 3)
|
||||||
|
labels_2 = ['x', 'w', 'a']
|
||||||
|
lt2 = LabelTensor(data_2, labels_2)
|
||||||
|
lt_cat = LabelTensor.cat([lt1, lt2], dim=2)
|
||||||
|
assert lt_cat.shape == (20, 3, 5)
|
||||||
|
assert lt_cat.labels[2]['dof'] == range(5)
|
||||||
|
assert lt_cat.labels[0]['dof'] == range(20)
|
||||||
|
assert lt_cat.labels[1]['dof'] == range(3)
|
||||||
|
|||||||
Reference in New Issue
Block a user