@@ -116,7 +116,10 @@ class LabelTensor(torch.Tensor):
|
||||
|
||||
new_data = self[:, indeces].float()
|
||||
new_labels = [self.labels[idx] for idx in indeces]
|
||||
extracted_tensor = LabelTensor(new_data, new_labels)
|
||||
|
||||
extracted_tensor = new_data.as_subclass(LabelTensor)
|
||||
extracted_tensor.labels = new_labels
|
||||
|
||||
|
||||
return extracted_tensor
|
||||
|
||||
@@ -150,9 +153,15 @@ class LabelTensor(torch.Tensor):
|
||||
tensor2.repeat_interleave(n1, dim=0),
|
||||
labels=tensor2.labels)
|
||||
new_tensor = torch.cat((tensor1, tensor2), dim=1)
|
||||
return LabelTensor(new_tensor, new_labels)
|
||||
|
||||
new_tensor = new_tensor.as_subclass(LabelTensor)
|
||||
new_tensor.labels = new_labels
|
||||
return new_tensor
|
||||
|
||||
def __str__(self):
|
||||
s = f'labels({str(self.labels)})\n'
|
||||
if hasattr(self, 'labels'):
|
||||
s = f'labels({str(self.labels)})\n'
|
||||
else:
|
||||
s = 'no labels\n'
|
||||
s += super().__str__()
|
||||
return s
|
||||
|
||||
Reference in New Issue
Block a user