fix old codes

This commit is contained in:
Your Name
2022-07-11 10:58:15 +02:00
parent 088649e042
commit f526a26050
19 changed files with 385 additions and 457 deletions

View File

@@ -12,7 +12,7 @@ class myFeature(torch.nn.Module):
super(myFeature, self).__init__()
def forward(self, x):
return torch.sin(torch.pi * x.extract('a'))
return LabelTensor(torch.sin(torch.pi * x.extract('a')), 'sin(a)')
data = torch.rand((20, 3))

View File

@@ -72,9 +72,8 @@ def test_merge():
def test_merge():
tensor = LabelTensor(data, labels)
tensor_a = tensor.extract('a')
tensor_b = tensor.extract('b')
tensor_c = tensor.extract('c')
tensor_bb = tensor_b.append(tensor_b)
assert torch.allclose(tensor_b, tensor.extract(['b', 'c']))
tensor_bc = tensor_b.append(tensor_c)
assert torch.allclose(tensor_bc, tensor.extract(['b', 'c']))