Vectorial output
This commit is contained in:
@@ -12,6 +12,9 @@ class LabelTensor():
|
||||
self.tensor = x
|
||||
|
||||
def __getitem__(self, key):
|
||||
if isinstance(key, (tuple, list)):
|
||||
indeces = [self.labels.index(k) for k in key]
|
||||
return LabelTensor(self.tensor[:, indeces], [self.labels[idx] for idx in indeces])
|
||||
if key in self.labels:
|
||||
return self.tensor[:, self.labels.index(key)]
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user