Vectorial output

This commit is contained in:
Nicola Demo
2022-03-07 10:09:40 +01:00
parent 1812ddb8d9
commit 8a1f07c8ae
6 changed files with 71 additions and 7 deletions

View File

@@ -12,6 +12,9 @@ class LabelTensor():
self.tensor = x
def __getitem__(self, key):
if isinstance(key, (tuple, list)):
indeces = [self.labels.index(k) for k in key]
return LabelTensor(self.tensor[:, indeces], [self.labels[idx] for idx in indeces])
if key in self.labels:
return self.tensor[:, self.labels.index(key)]
else: