From a88814170775c1785be05db9f2319c32ddae05ed Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 30 Sep 2024 12:23:15 +0200 Subject: [PATCH] Add concatenation test for LabelTensor --- tests/test_label_tensor.py | 60 +++++++++++++++++++++++++++++++++++++- 1 file changed, 59 insertions(+), 1 deletion(-) diff --git a/tests/test_label_tensor.py b/tests/test_label_tensor.py index 7484a49..f87d3ab 100644 --- a/tests/test_label_tensor.py +++ b/tests/test_label_tensor.py @@ -87,4 +87,62 @@ def test_extract_3D(): assert torch.all(torch.isclose( data[:, 0::2, 1:4].reshape(20, 2, 3), new - )) \ No newline at end of file + )) + +def test_concatenation_3D(): + data_1 = torch.rand(20, 3, 4) + labels_1 = ['x', 'y', 'z', 'w'] + lt1 = LabelTensor(data_1, labels_1) + data_2 = torch.rand(50, 3, 4) + labels_2 = ['x', 'y', 'z', 'w'] + lt2 = LabelTensor(data_2, labels_2) + lt_cat = LabelTensor.cat([lt1, lt2]) + assert lt_cat.shape == (70, 3, 4) + assert lt_cat.labels[0]['dof'] == range(70) + assert lt_cat.labels[1]['dof'] == range(3) + assert lt_cat.labels[2]['dof'] == ['x', 'y', 'z', 'w'] + + data_1 = torch.rand(20, 3, 4) + labels_1 = ['x', 'y', 'z', 'w'] + lt1 = LabelTensor(data_1, labels_1) + data_2 = torch.rand(20, 2, 4) + labels_2 = ['x', 'y', 'z', 'w'] + lt2 = LabelTensor(data_2, labels_2) + lt_cat = LabelTensor.cat([lt1, lt2], dim=1) + assert lt_cat.shape == (20, 5, 4) + assert lt_cat.labels[0]['dof'] == range(20) + assert lt_cat.labels[1]['dof'] == range(5) + assert lt_cat.labels[2]['dof'] == ['x', 'y', 'z', 'w'] + + data_1 = torch.rand(20, 3, 2) + labels_1 = ['x', 'y'] + lt1 = LabelTensor(data_1, labels_1) + data_2 = torch.rand(20, 3, 3) + labels_2 = ['z', 'w', 'a'] + lt2 = LabelTensor(data_2, labels_2) + lt_cat = LabelTensor.cat([lt1, lt2], dim=2) + assert lt_cat.shape == (20, 3, 5) + assert lt_cat.labels[2]['dof'] == ['x', 'y', 'z', 'w', 'a'] + assert lt_cat.labels[0]['dof'] == range(20) + assert lt_cat.labels[1]['dof'] == range(3) + + data_1 = torch.rand(20, 2, 4) + labels_1 = ['x', 'y', 'z', 'w'] + lt1 = LabelTensor(data_1, labels_1) + data_2 = torch.rand(20, 3, 4) + labels_2 = ['x', 'y', 'z', 'w'] + lt2 = LabelTensor(data_2, labels_2) + with pytest.raises(ValueError): + LabelTensor.cat([lt1, lt2], dim=2) + + data_1 = torch.rand(20, 3, 2) + labels_1 = ['x', 'y'] + lt1 = LabelTensor(data_1, labels_1) + data_2 = torch.rand(20, 3, 3) + labels_2 = ['x', 'w', 'a'] + lt2 = LabelTensor(data_2, labels_2) + lt_cat = LabelTensor.cat([lt1, lt2], dim=2) + assert lt_cat.shape == (20, 3, 5) + assert lt_cat.labels[2]['dof'] == range(5) + assert lt_cat.labels[0]['dof'] == range(20) + assert lt_cat.labels[1]['dof'] == range(3)