Add concatenation test for LabelTensor
This commit is contained in:
committed by
Nicola Demo
parent
16351f95ae
commit
a888141707
@@ -88,3 +88,61 @@ def test_extract_3D():
|
||||
data[:, 0::2, 1:4].reshape(20, 2, 3),
|
||||
new
|
||||
))
|
||||
|
||||
def test_concatenation_3D():
|
||||
data_1 = torch.rand(20, 3, 4)
|
||||
labels_1 = ['x', 'y', 'z', 'w']
|
||||
lt1 = LabelTensor(data_1, labels_1)
|
||||
data_2 = torch.rand(50, 3, 4)
|
||||
labels_2 = ['x', 'y', 'z', 'w']
|
||||
lt2 = LabelTensor(data_2, labels_2)
|
||||
lt_cat = LabelTensor.cat([lt1, lt2])
|
||||
assert lt_cat.shape == (70, 3, 4)
|
||||
assert lt_cat.labels[0]['dof'] == range(70)
|
||||
assert lt_cat.labels[1]['dof'] == range(3)
|
||||
assert lt_cat.labels[2]['dof'] == ['x', 'y', 'z', 'w']
|
||||
|
||||
data_1 = torch.rand(20, 3, 4)
|
||||
labels_1 = ['x', 'y', 'z', 'w']
|
||||
lt1 = LabelTensor(data_1, labels_1)
|
||||
data_2 = torch.rand(20, 2, 4)
|
||||
labels_2 = ['x', 'y', 'z', 'w']
|
||||
lt2 = LabelTensor(data_2, labels_2)
|
||||
lt_cat = LabelTensor.cat([lt1, lt2], dim=1)
|
||||
assert lt_cat.shape == (20, 5, 4)
|
||||
assert lt_cat.labels[0]['dof'] == range(20)
|
||||
assert lt_cat.labels[1]['dof'] == range(5)
|
||||
assert lt_cat.labels[2]['dof'] == ['x', 'y', 'z', 'w']
|
||||
|
||||
data_1 = torch.rand(20, 3, 2)
|
||||
labels_1 = ['x', 'y']
|
||||
lt1 = LabelTensor(data_1, labels_1)
|
||||
data_2 = torch.rand(20, 3, 3)
|
||||
labels_2 = ['z', 'w', 'a']
|
||||
lt2 = LabelTensor(data_2, labels_2)
|
||||
lt_cat = LabelTensor.cat([lt1, lt2], dim=2)
|
||||
assert lt_cat.shape == (20, 3, 5)
|
||||
assert lt_cat.labels[2]['dof'] == ['x', 'y', 'z', 'w', 'a']
|
||||
assert lt_cat.labels[0]['dof'] == range(20)
|
||||
assert lt_cat.labels[1]['dof'] == range(3)
|
||||
|
||||
data_1 = torch.rand(20, 2, 4)
|
||||
labels_1 = ['x', 'y', 'z', 'w']
|
||||
lt1 = LabelTensor(data_1, labels_1)
|
||||
data_2 = torch.rand(20, 3, 4)
|
||||
labels_2 = ['x', 'y', 'z', 'w']
|
||||
lt2 = LabelTensor(data_2, labels_2)
|
||||
with pytest.raises(ValueError):
|
||||
LabelTensor.cat([lt1, lt2], dim=2)
|
||||
|
||||
data_1 = torch.rand(20, 3, 2)
|
||||
labels_1 = ['x', 'y']
|
||||
lt1 = LabelTensor(data_1, labels_1)
|
||||
data_2 = torch.rand(20, 3, 3)
|
||||
labels_2 = ['x', 'w', 'a']
|
||||
lt2 = LabelTensor(data_2, labels_2)
|
||||
lt_cat = LabelTensor.cat([lt1, lt2], dim=2)
|
||||
assert lt_cat.shape == (20, 3, 5)
|
||||
assert lt_cat.labels[2]['dof'] == range(5)
|
||||
assert lt_cat.labels[0]['dof'] == range(20)
|
||||
assert lt_cat.labels[1]['dof'] == range(3)
|
||||
|
||||
Reference in New Issue
Block a user