fix old codes

This commit is contained in:
Your Name
2022-07-11 10:58:15 +02:00
parent 088649e042
commit f526a26050
19 changed files with 385 additions and 457 deletions

View File

@@ -55,6 +55,8 @@ class LabelTensor(torch.Tensor):
[0.9518, 0.1025],
[0.8066, 0.9615]])
'''
if x.ndim == 1:
x = x.reshape(-1, 1)
if isinstance(labels, str):
labels = [labels]
@@ -130,3 +132,8 @@ class LabelTensor(torch.Tensor):
new_labels = self.labels + lt.labels
new_tensor = torch.cat((self, lt), dim=1)
return LabelTensor(new_tensor, new_labels)
def __str__(self):
s = f'labels({str(self.labels)})\n'
s += super().__str__()
return s