"""Tests for ``LabelTensor`` construction, extraction and concatenation."""

import pytest
import torch

from pina.label_tensor import LabelTensor

# Shared 2D fixture: 20 rows by 3 columns.
data = torch.rand((20, 3))

# Labels for dimension 1 only (the three columns).
labels_column = {
    1: {
        "name": "space",
        "dof": ['x', 'y', 'z']
    }
}

# Labels for dimension 0 only (the twenty rows).
labels_row = {
    0: {
        "name": "samples",
        "dof": range(20)
    }
}

# Labels for both dimensions at once.
labels_all = labels_column | labels_row


@pytest.mark.parametrize("labels", [labels_column, labels_row, labels_all])
def test_constructor(labels):
    """Construction succeeds with column-only, row-only and full labels."""
    LabelTensor(data, labels)


def test_wrong_constructor():
    """A two-item label list on the (20, 3) fixture must raise ValueError."""
    with pytest.raises(ValueError):
        LabelTensor(data, ['a', 'b'])


@pytest.mark.parametrize("labels", [labels_column, labels_all])
@pytest.mark.parametrize("labels_te", ['z', ['z'], {'space': ['z']}])
def test_extract_column(labels, labels_te):
    """Extracting 'z' (as str, list or dict) yields the last column."""
    tensor = LabelTensor(data, labels)
    new = tensor.extract(labels_te)

    # The extraction keeps the number of dimensions: one column, all rows.
    assert new.ndim == tensor.ndim
    assert new.shape[1] == 1
    assert new.shape[0] == 20
    assert torch.all(torch.isclose(data[:, 2].reshape(-1, 1), new))


@pytest.mark.parametrize("labels", [labels_row, labels_all])
@pytest.mark.parametrize("labels_te", [{'samples': [2]}])
def test_extract_row(labels, labels_te):
    """Extracting sample 2 yields the third row with all columns."""
    tensor = LabelTensor(data, labels)
    new = tensor.extract(labels_te)

    assert new.ndim == tensor.ndim
    assert new.shape[1] == 3
    assert new.shape[0] == 1
    assert torch.all(torch.isclose(data[2].reshape(1, -1), new))


@pytest.mark.parametrize("labels_te", [
    {'samples': [2], 'space': ['z']},
    {'space': 'z', 'samples': 2}
])
def test_extract_2D(labels_te):
    """Extracting along both dimensions yields the single entry [2, 2].

    Both list-valued and scalar-valued selections are exercised, in either
    key order.
    """
    labels = labels_all
    tensor = LabelTensor(data, labels)
    new = tensor.extract(labels_te)

    assert new.ndim == tensor.ndim
    assert new.shape[1] == 1
    assert new.shape[0] == 1
    assert torch.all(torch.isclose(data[2, 2].reshape(1, 1), new))


def test_extract_3D():
    """Extract a sub-block of a 3D tensor labelled on dimensions 1 and 2."""
    # Local tensor, named so it does not shadow the module-level ``data``.
    data_3d = torch.rand(20, 3, 4)
    labels = {
        1: {
            "name": "space",
            "dof": ['x', 'y', 'z']
        },
        2: {
            "name": "time",
            "dof": range(4)
        },
    }
    labels_te = {
        'space': ['x', 'z'],
        'time': range(1, 4)
    }

    tensor = LabelTensor(data_3d, labels)
    new = tensor.extract(labels_te)

    assert new.ndim == tensor.ndim
    assert new.shape[0] == 20
    assert new.shape[1] == 2
    assert new.shape[2] == 3
    # 'x' and 'z' are indices 0 and 2 on dim 1; times 1..3 are 1:4 on dim 2.
    assert torch.all(torch.isclose(
        data_3d[:, 0::2, 1:4].reshape(20, 2, 3),
        new
    ))


def test_concatenation_3D():
    """Check ``LabelTensor.cat`` on 3D tensors along each dimension.

    Five scenarios run in sequence; the asserted label layout after each
    concatenation is spelled out next to the scenario.
    """
    # Scenario 1: no ``dim`` given; the 20 and 50 row inputs give 70 rows.
    data_1 = torch.rand(20, 3, 4)
    labels_1 = ['x', 'y', 'z', 'w']
    lt1 = LabelTensor(data_1, labels_1)
    data_2 = torch.rand(50, 3, 4)
    labels_2 = ['x', 'y', 'z', 'w']
    lt2 = LabelTensor(data_2, labels_2)
    lt_cat = LabelTensor.cat([lt1, lt2])
    assert lt_cat.shape == (70, 3, 4)
    assert lt_cat.labels[0]['dof'] == range(70)
    assert lt_cat.labels[1]['dof'] == range(3)
    assert lt_cat.labels[2]['dof'] == ['x', 'y', 'z', 'w']

    # Scenario 2: dim=1; sizes 3 and 2 give 5, last-dimension labels kept.
    data_1 = torch.rand(20, 3, 4)
    labels_1 = ['x', 'y', 'z', 'w']
    lt1 = LabelTensor(data_1, labels_1)
    data_2 = torch.rand(20, 2, 4)
    labels_2 = ['x', 'y', 'z', 'w']
    lt2 = LabelTensor(data_2, labels_2)
    lt_cat = LabelTensor.cat([lt1, lt2], dim=1)
    assert lt_cat.shape == (20, 5, 4)
    assert lt_cat.labels[0]['dof'] == range(20)
    assert lt_cat.labels[1]['dof'] == range(5)
    assert lt_cat.labels[2]['dof'] == ['x', 'y', 'z', 'w']

    # Scenario 3: dim=2 with disjoint labels; the label lists are joined.
    data_1 = torch.rand(20, 3, 2)
    labels_1 = ['x', 'y']
    lt1 = LabelTensor(data_1, labels_1)
    data_2 = torch.rand(20, 3, 3)
    labels_2 = ['z', 'w', 'a']
    lt2 = LabelTensor(data_2, labels_2)
    lt_cat = LabelTensor.cat([lt1, lt2], dim=2)
    assert lt_cat.shape == (20, 3, 5)
    assert lt_cat.labels[2]['dof'] == ['x', 'y', 'z', 'w', 'a']
    assert lt_cat.labels[0]['dof'] == range(20)
    assert lt_cat.labels[1]['dof'] == range(3)

    # Scenario 4: dim=2 while dimension 1 differs (2 vs 3); must raise.
    data_1 = torch.rand(20, 2, 4)
    labels_1 = ['x', 'y', 'z', 'w']
    lt1 = LabelTensor(data_1, labels_1)
    data_2 = torch.rand(20, 3, 4)
    labels_2 = ['x', 'y', 'z', 'w']
    lt2 = LabelTensor(data_2, labels_2)
    with pytest.raises(ValueError):
        LabelTensor.cat([lt1, lt2], dim=2)

    # Scenario 5: dim=2 with both inputs carrying 'x'; the asserted labels
    # on that dimension are range(5) rather than the joined names.
    data_1 = torch.rand(20, 3, 2)
    labels_1 = ['x', 'y']
    lt1 = LabelTensor(data_1, labels_1)
    data_2 = torch.rand(20, 3, 3)
    labels_2 = ['x', 'w', 'a']
    lt2 = LabelTensor(data_2, labels_2)
    lt_cat = LabelTensor.cat([lt1, lt2], dim=2)
    assert lt_cat.shape == (20, 3, 5)
    assert lt_cat.labels[2]['dof'] == range(5)
    assert lt_cat.labels[0]['dof'] == range(20)
    assert lt_cat.labels[1]['dof'] == range(3)