use LabelTensor, fix minor, docs
This commit is contained in:
@@ -81,11 +81,10 @@ class LabelTensor(torch.Tensor):
|
||||
Performs Tensor dtype and/or device conversion. For more details, see
|
||||
:meth:`torch.Tensor.to`.
|
||||
"""
|
||||
new_obj = LabelTensor([], self.labels)
|
||||
tempTensor = super().to(*args, **kwargs)
|
||||
new_obj.data = tempTensor.data
|
||||
new_obj.requires_grad = tempTensor.requires_grad
|
||||
return new_obj
|
||||
tmp = super().to(*args, **kwargs)
|
||||
new = self.__class__.clone(self)
|
||||
new.data = tmp.data
|
||||
return new
|
||||
|
||||
def extract(self, label_to_extract):
|
||||
"""
|
||||
@@ -111,9 +110,23 @@ class LabelTensor(torch.Tensor):
|
||||
except ValueError:
|
||||
raise ValueError('`label_to_extract` not in the labels list')
|
||||
|
||||
extracted_tensor = LabelTensor(
|
||||
self[:, indeces],
|
||||
[self.labels[idx] for idx in indeces]
|
||||
)
|
||||
new_data = self[:, indeces].float()
|
||||
new_labels = [self.labels[idx] for idx in indeces]
|
||||
extracted_tensor = LabelTensor(new_data, new_labels)
|
||||
|
||||
return extracted_tensor
|
||||
|
||||
def append(self, lt):
|
||||
"""
|
||||
Return a copy of the merged tensors.
|
||||
|
||||
:param LabelTensor lt: the tensor to merge.
|
||||
:return: the merged tensors
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
if set(self.labels).intersection(lt.labels):
|
||||
raise RuntimeError('The tensors to merge have common labels')
|
||||
|
||||
new_labels = self.labels + lt.labels
|
||||
new_tensor = torch.cat((self, lt), dim=1)
|
||||
return LabelTensor(new_tensor, new_labels)
|
||||
|
||||
Reference in New Issue
Block a user