Files
PINA/pina/label_tensor.py
2021-11-29 15:29:00 +01:00

50 lines
1.2 KiB
Python

import torch
class LabelTensor():
    """Wrap a tensor and attach one label to each column (``x.shape[1]``).

    Indexing with a label (``lt['u']``) returns that column; any other key
    is forwarded unchanged to the wrapped tensor's own indexing.
    """

    def __init__(self, x, labels):
        """
        :param x: tensor whose second dimension holds the labelled columns.
        :param labels: iterable of column labels, one per column of ``x``.
            It is copied into a new list so later changes to the caller's
            sequence cannot desynchronise labels and columns.
        :raises ValueError: if the number of labels differs from
            ``x.shape[1]``.
        """
        labels = list(labels)
        if len(labels) != x.shape[1]:
            raise ValueError(
                f"expected {x.shape[1]} labels (one per column), "
                f"got {len(labels)}"
            )
        self.__labels = labels
        self.tensor = x

    def __getitem__(self, key):
        """Return the column labelled ``key``, else index the raw tensor.

        :param key: a column label, or any key accepted by the wrapped
            tensor's ``__getitem__`` (slices, tuples, ints, ...).
        """
        if key in self.labels:
            return self.tensor[:, self.labels.index(key)]
        return self.tensor[key]

    def __repr__(self):
        # __repr__ must return a str; returning the tensor raised TypeError.
        return f"LabelTensor({self.tensor!r}, labels={self.labels!r})"

    def __str__(self):
        # __str__ must return a str; returning a tuple raised TypeError.
        return f"{self.tensor}\nlabels: {self.labels}"

    @property
    def labels(self):
        """List of column labels, in column order."""
        return self.__labels

    @staticmethod
    def hstack(labeltensor_list):
        """Concatenate LabelTensors column-wise.

        :param labeltensor_list: iterable of LabelTensor objects whose
            tensors can be concatenated along dimension 1.
        :return: a new LabelTensor whose labels are the inputs' labels,
            in input order.
        """
        # Materialise once: the sequence is traversed twice below.
        labeltensor_list = list(labeltensor_list)
        concatenated_tensor = torch.cat(
            [lt.tensor for lt in labeltensor_list], dim=1)
        concatenated_label = [
            label for lt in labeltensor_list for label in lt.labels]
        return LabelTensor(concatenated_tensor, concatenated_label)
if __name__ == "__main__":
    # Small demo: label the three columns of a random (20, 3) tensor.
    import numpy as np
    a = np.random.uniform(size=(20, 3))
    p = torch.from_numpy(a)
    t = LabelTensor(p, labels=['u', 'p', 't'])
    print(t)
    print(t['u'])
    # LabelTensor defines no arithmetic operators (`t *= 2` raised
    # TypeError), so scale the wrapped tensor in place instead.
    t.tensor *= 2
    print(t['u'])
    print(t[:, 0])