use LabelTensor, fix minor, docs

This commit is contained in:
Your Name
2022-03-29 18:05:26 +02:00
parent 12f4084d7f
commit 6b001c6c53
19 changed files with 370 additions and 322 deletions

View File

@@ -81,11 +81,10 @@ class LabelTensor(torch.Tensor):
Performs Tensor dtype and/or device conversion. For more details, see
:meth:`torch.Tensor.to`.
"""
new_obj = LabelTensor([], self.labels)
tempTensor = super().to(*args, **kwargs)
new_obj.data = tempTensor.data
new_obj.requires_grad = tempTensor.requires_grad
return new_obj
tmp = super().to(*args, **kwargs)
new = self.__class__.clone(self)
new.data = tmp.data
return new
def extract(self, label_to_extract):
"""
@@ -111,9 +110,23 @@ class LabelTensor(torch.Tensor):
except ValueError:
raise ValueError('`label_to_extract` not in the labels list')
extracted_tensor = LabelTensor(
self[:, indeces],
[self.labels[idx] for idx in indeces]
)
new_data = self[:, indeces].float()
new_labels = [self.labels[idx] for idx in indeces]
extracted_tensor = LabelTensor(new_data, new_labels)
return extracted_tensor
def append(self, lt):
"""
Return a copy of the merged tensors.
:param LabelTensor lt: the tensor to merge.
:return: the merged tensors
:rtype: LabelTensor
"""
if set(self.labels).intersection(lt.labels):
raise RuntimeError('The tensors to merge have common labels')
new_labels = self.labels + lt.labels
new_tensor = torch.cat((self, lt), dim=1)
return LabelTensor(new_tensor, new_labels)