Filippo0.2 (#361)

* Add summation and remove deepcopy (only for tensors) in LabelTensor class
* Update operators for compatibility with updated LabelTensor implementation
* Implement labels.setter in LabelTensor class
* Update LabelTensor

---------

Co-authored-by: FilippoOlivo <filippo@filippoolivo.com>
This commit is contained in:
Dario Coscia
2024-10-04 15:59:09 +02:00
committed by Nicola Demo
parent 1d3df2a127
commit fdb8f65143
4 changed files with 212 additions and 92 deletions

View File

@@ -17,12 +17,14 @@ labels_row = {
"dof": range(20)
}
}
labels_list = ['x', 'y', 'z']
labels_all = labels_column | labels_row
@pytest.mark.parametrize("labels", [labels_column, labels_row, labels_all])
@pytest.mark.parametrize("labels", [labels_column, labels_row, labels_all, labels_list])
def test_constructor(labels):
LabelTensor(data, labels)
def test_wrong_constructor():
with pytest.raises(ValueError):
LabelTensor(data, ['a', 'b'])
@@ -61,7 +63,6 @@ def test_extract_2D(labels_te):
assert torch.all(torch.isclose(data[2,2].reshape(1, 1), new))
def test_extract_3D():
labels = labels_all
data = torch.rand(20, 3, 4)
labels = {
1: {
@@ -80,6 +81,7 @@ def test_extract_3D():
tensor = LabelTensor(data, labels)
new = tensor.extract(labels_te)
tensor2 = LabelTensor(data, labels)
assert new.ndim == tensor.ndim
assert new.shape[0] == 20
assert new.shape[1] == 2
@@ -88,6 +90,10 @@ def test_extract_3D():
data[:, 0::2, 1:4].reshape(20, 2, 3),
new
))
assert tensor2.ndim == tensor.ndim
assert tensor2.shape == tensor.shape
assert tensor.labels == tensor2.labels
assert new.shape != tensor.shape
def test_concatenation_3D():
data_1 = torch.rand(20, 3, 4)
@@ -146,3 +152,51 @@ def test_concatenation_3D():
assert lt_cat.labels[2]['dof'] == range(5)
assert lt_cat.labels[0]['dof'] == range(20)
assert lt_cat.labels[1]['dof'] == range(3)
def test_summation():
lt1 = LabelTensor(torch.ones(20,3), labels_all)
lt2 = LabelTensor(torch.ones(30,3), ['x', 'y', 'z'])
with pytest.raises(RuntimeError):
LabelTensor.summation([lt1, lt2])
lt1 = LabelTensor(torch.ones(20,3), labels_all)
lt2 = LabelTensor(torch.ones(20,3), labels_all)
lt_sum = LabelTensor.summation([lt1, lt2])
assert lt_sum.ndim == lt_sum.ndim
assert lt_sum.shape[0] == 20
assert lt_sum.shape[1] == 3
assert lt_sum.labels == labels_all
assert torch.eq(lt_sum.tensor, torch.ones(20,3)*2).all()
lt1 = LabelTensor(torch.ones(20,3), labels_all)
lt2 = LabelTensor(torch.ones(20,3), labels_all)
lt3 = LabelTensor(torch.zeros(20, 3), labels_all)
lt_sum = LabelTensor.summation([lt1, lt2, lt3])
assert lt_sum.ndim == lt_sum.ndim
assert lt_sum.shape[0] == 20
assert lt_sum.shape[1] == 3
assert lt_sum.labels == labels_all
assert torch.eq(lt_sum.tensor, torch.ones(20,3)*2).all()
def test_append_3D():
data_1 = torch.rand(20, 3, 4)
labels_1 = ['x', 'y', 'z', 'w']
lt1 = LabelTensor(data_1, labels_1)
data_2 = torch.rand(50, 3, 4)
labels_2 = ['x', 'y', 'z', 'w']
lt2 = LabelTensor(data_2, labels_2)
lt1 = lt1.append(lt2)
assert lt1.shape == (70, 3, 4)
assert lt1.labels[0]['dof'] == range(70)
assert lt1.labels[1]['dof'] == range(3)
assert lt1.labels[2]['dof'] == ['x', 'y', 'z', 'w']
data_1 = torch.rand(20, 3, 2)
labels_1 = ['x', 'y']
lt1 = LabelTensor(data_1, labels_1)
data_2 = torch.rand(20, 3, 2)
labels_2 = ['z', 'w']
lt2 = LabelTensor(data_2, labels_2)
lt1 = lt1.append(lt2, mode='cross')
assert lt1.shape == (20, 3, 4)
assert lt1.labels[0]['dof'] == range(20)
assert lt1.labels[1]['dof'] == range(3)
assert lt1.labels[2]['dof'] == ['x', 'y', 'z', 'w']