fix old codes
This commit is contained in:
@@ -55,6 +55,8 @@ class LabelTensor(torch.Tensor):
|
||||
[0.9518, 0.1025],
|
||||
[0.8066, 0.9615]])
|
||||
'''
|
||||
if x.ndim == 1:
|
||||
x = x.reshape(-1, 1)
|
||||
|
||||
if isinstance(labels, str):
|
||||
labels = [labels]
|
||||
@@ -130,3 +132,8 @@ class LabelTensor(torch.Tensor):
|
||||
new_labels = self.labels + lt.labels
|
||||
new_tensor = torch.cat((self, lt), dim=1)
|
||||
return LabelTensor(new_tensor, new_labels)
|
||||
|
||||
def __str__(self):
|
||||
s = f'labels({str(self.labels)})\n'
|
||||
s += super().__str__()
|
||||
return s
|
||||
|
||||
Reference in New Issue
Block a user