fix old codes
This commit is contained in:
@@ -12,7 +12,7 @@ class myFeature(torch.nn.Module):
|
||||
super(myFeature, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return torch.sin(torch.pi * x.extract('a'))
|
||||
return LabelTensor(torch.sin(torch.pi * x.extract('a')), 'sin(a)')
|
||||
|
||||
|
||||
data = torch.rand((20, 3))
|
||||
|
||||
@@ -72,9 +72,8 @@ def test_merge():
|
||||
|
||||
def test_merge():
|
||||
tensor = LabelTensor(data, labels)
|
||||
tensor_a = tensor.extract('a')
|
||||
tensor_b = tensor.extract('b')
|
||||
tensor_c = tensor.extract('c')
|
||||
|
||||
tensor_bb = tensor_b.append(tensor_b)
|
||||
assert torch.allclose(tensor_b, tensor.extract(['b', 'c']))
|
||||
tensor_bc = tensor_b.append(tensor_c)
|
||||
assert torch.allclose(tensor_bc, tensor.extract(['b', 'c']))
|
||||
|
||||
Reference in New Issue
Block a user