minor fix
This commit is contained in:
@@ -107,10 +107,12 @@ class LabelTensor(torch.Tensor):
|
||||
raise TypeError(
|
||||
'`label_to_extract` should be a str, or a str iterator')
|
||||
|
||||
try:
|
||||
indeces = [self.labels.index(f) for f in label_to_extract]
|
||||
except ValueError:
|
||||
raise ValueError('`label_to_extract` not in the labels list')
|
||||
indeces = []
|
||||
for f in label_to_extract:
|
||||
try:
|
||||
indeces.append(self.labels.index(f))
|
||||
except ValueError:
|
||||
raise ValueError(f'`{f}` not in the labels list')
|
||||
|
||||
new_data = self[:, indeces].float()
|
||||
new_labels = [self.labels[idx] for idx in indeces]
|
||||
@@ -118,11 +120,12 @@ class LabelTensor(torch.Tensor):
|
||||
|
||||
return extracted_tensor
|
||||
|
||||
def append(self, lt):
|
||||
def append(self, lt, mode='std'):
|
||||
"""
|
||||
Return a copy of the merged tensors.
|
||||
|
||||
:param LabelTensor lt: the tensor to merge.
|
||||
:param str mode: {'std', 'first', 'cross'}
|
||||
:return: the merged tensors
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
@@ -130,7 +133,23 @@ class LabelTensor(torch.Tensor):
|
||||
raise RuntimeError('The tensors to merge have common labels')
|
||||
|
||||
new_labels = self.labels + lt.labels
|
||||
new_tensor = torch.cat((self, lt), dim=1)
|
||||
if mode == 'std':
|
||||
new_tensor = torch.cat((self, lt), dim=1)
|
||||
elif mode == 'first':
|
||||
raise NotImplementedError
|
||||
elif mode == 'cross':
|
||||
tensor1 = self
|
||||
tensor2 = lt
|
||||
n1 = tensor1.shape[0]
|
||||
n2 = tensor2.shape[0]
|
||||
|
||||
tensor1 = LabelTensor(
|
||||
tensor1.repeat(n2, 1),
|
||||
labels=tensor1.labels)
|
||||
tensor2 = LabelTensor(
|
||||
tensor2.repeat_interleave(n1, dim=0),
|
||||
labels=tensor2.labels)
|
||||
new_tensor = torch.cat((tensor1, tensor2), dim=1)
|
||||
return LabelTensor(new_tensor, new_labels)
|
||||
|
||||
def __str__(self):
|
||||
|
||||
Reference in New Issue
Block a user