CUDA option for labeltensor (#23)

* fix cuda device for labeltensor
This commit is contained in:
Nicola Demo
2022-09-08 17:31:49 +02:00
committed by GitHub
parent 9b2ab7be41
commit 06932196a8
5 changed files with 61 additions and 56 deletions

View File

@@ -116,7 +116,10 @@ class LabelTensor(torch.Tensor):
new_data = self[:, indeces].float()
new_labels = [self.labels[idx] for idx in indeces]
extracted_tensor = LabelTensor(new_data, new_labels)
extracted_tensor = new_data.as_subclass(LabelTensor)
extracted_tensor.labels = new_labels
return extracted_tensor
@@ -150,9 +153,15 @@ class LabelTensor(torch.Tensor):
tensor2.repeat_interleave(n1, dim=0),
labels=tensor2.labels)
new_tensor = torch.cat((tensor1, tensor2), dim=1)
return LabelTensor(new_tensor, new_labels)
new_tensor = new_tensor.as_subclass(LabelTensor)
new_tensor.labels = new_labels
return new_tensor
def __str__(self):
s = f'labels({str(self.labels)})\n'
if hasattr(self, 'labels'):
s = f'labels({str(self.labels)})\n'
else:
s = 'no labels\n'
s += super().__str__()
return s