minor fix

This commit is contained in:
Your Name
2022-07-20 17:23:53 +02:00
committed by Nicola Demo
parent 75a81af99c
commit a05adea4e3
10 changed files with 231 additions and 203 deletions

View File

@@ -107,10 +107,12 @@ class LabelTensor(torch.Tensor):
raise TypeError(
'`label_to_extract` should be a str, or a str iterator')
try:
indeces = [self.labels.index(f) for f in label_to_extract]
except ValueError:
raise ValueError('`label_to_extract` not in the labels list')
indeces = []
for f in label_to_extract:
try:
indeces.append(self.labels.index(f))
except ValueError:
raise ValueError(f'`{f}` not in the labels list')
new_data = self[:, indeces].float()
new_labels = [self.labels[idx] for idx in indeces]
@@ -118,11 +120,12 @@ class LabelTensor(torch.Tensor):
return extracted_tensor
def append(self, lt):
def append(self, lt, mode='std'):
"""
Return a copy of the merged tensors.
:param LabelTensor lt: the tensor to merge.
:param str mode: {'std', 'first', 'cross'}
:return: the merged tensors
:rtype: LabelTensor
"""
@@ -130,7 +133,23 @@ class LabelTensor(torch.Tensor):
raise RuntimeError('The tensors to merge have common labels')
new_labels = self.labels + lt.labels
new_tensor = torch.cat((self, lt), dim=1)
if mode == 'std':
new_tensor = torch.cat((self, lt), dim=1)
elif mode == 'first':
raise NotImplementedError
elif mode == 'cross':
tensor1 = self
tensor2 = lt
n1 = tensor1.shape[0]
n2 = tensor2.shape[0]
tensor1 = LabelTensor(
tensor1.repeat(n2, 1),
labels=tensor1.labels)
tensor2 = LabelTensor(
tensor2.repeat_interleave(n1, dim=0),
labels=tensor2.labels)
new_tensor = torch.cat((tensor1, tensor2), dim=1)
return LabelTensor(new_tensor, new_labels)
def __str__(self):