🎨 Format Python code with psf/black
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
""" Module for LabelTensor """
|
||||
|
||||
from typing import Any
|
||||
import torch
|
||||
from torch import Tensor
|
||||
@@ -12,7 +13,7 @@ class LabelTensor(torch.Tensor):
|
||||
return super().__new__(cls, x, *args, **kwargs)
|
||||
|
||||
def __init__(self, x, labels):
|
||||
'''
|
||||
"""
|
||||
Construct a `LabelTensor` by passing a tensor and a list of column
|
||||
labels. Such labels uniquely identify the columns of the tensor,
|
||||
allowing for an easier manipulation.
|
||||
@@ -64,7 +65,7 @@ class LabelTensor(torch.Tensor):
|
||||
[0.9427, 0.5819],
|
||||
[0.9518, 0.1025],
|
||||
[0.8066, 0.9615]])
|
||||
'''
|
||||
"""
|
||||
if x.ndim == 1:
|
||||
x = x.reshape(-1, 1)
|
||||
|
||||
@@ -72,10 +73,12 @@ class LabelTensor(torch.Tensor):
|
||||
labels = [labels]
|
||||
|
||||
if len(labels) != x.shape[-1]:
|
||||
raise ValueError('the tensor has not the same number of columns of '
|
||||
'the passed labels.')
|
||||
raise ValueError(
|
||||
"the tensor has not the same number of columns of "
|
||||
"the passed labels."
|
||||
)
|
||||
self._labels = labels
|
||||
|
||||
|
||||
@property
|
||||
def labels(self):
|
||||
"""Property decorator for labels
|
||||
@@ -88,8 +91,10 @@ class LabelTensor(torch.Tensor):
|
||||
@labels.setter
|
||||
def labels(self, labels):
|
||||
if len(labels) != self.shape[self.ndim - 1]: # small check
|
||||
raise ValueError('The tensor has not the same number of columns of '
|
||||
'the passed labels.')
|
||||
raise ValueError(
|
||||
"The tensor has not the same number of columns of "
|
||||
"the passed labels."
|
||||
)
|
||||
|
||||
self._labels = labels # assign the label
|
||||
|
||||
@@ -109,7 +114,7 @@ class LabelTensor(torch.Tensor):
|
||||
|
||||
all_labels = [label for lt in label_tensors for label in lt.labels]
|
||||
if set(all_labels) != set(label_tensors[0].labels):
|
||||
raise RuntimeError('The tensors to stack have different labels')
|
||||
raise RuntimeError("The tensors to stack have different labels")
|
||||
|
||||
labels = label_tensors[0].labels
|
||||
tensors = [lt.extract(labels) for lt in label_tensors]
|
||||
@@ -123,7 +128,7 @@ class LabelTensor(torch.Tensor):
|
||||
:return: A copy of the tensor.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
# # used before merging
|
||||
# # used before merging
|
||||
# try:
|
||||
# out = LabelTensor(super().clone(*args, **kwargs), self.labels)
|
||||
# except:
|
||||
@@ -184,14 +189,15 @@ class LabelTensor(torch.Tensor):
|
||||
pass
|
||||
else:
|
||||
raise TypeError(
|
||||
'`label_to_extract` should be a str, or a str iterator')
|
||||
"`label_to_extract` should be a str, or a str iterator"
|
||||
)
|
||||
|
||||
indeces = []
|
||||
for f in label_to_extract:
|
||||
try:
|
||||
indeces.append(self.labels.index(f))
|
||||
except ValueError:
|
||||
raise ValueError(f'`{f}` not in the labels list')
|
||||
raise ValueError(f"`{f}` not in the labels list")
|
||||
|
||||
new_data = super(Tensor, self.T).__getitem__(indeces).T
|
||||
new_labels = [self.labels[idx] for idx in indeces]
|
||||
@@ -203,17 +209,16 @@ class LabelTensor(torch.Tensor):
|
||||
|
||||
def detach(self):
|
||||
detached = super().detach()
|
||||
if hasattr(self, '_labels'):
|
||||
if hasattr(self, "_labels"):
|
||||
detached._labels = self._labels
|
||||
return detached
|
||||
|
||||
|
||||
def requires_grad_(self, mode = True):
|
||||
def requires_grad_(self, mode=True):
|
||||
lt = super().requires_grad_(mode)
|
||||
lt.labels = self.labels
|
||||
return lt
|
||||
|
||||
def append(self, lt, mode='std'):
|
||||
def append(self, lt, mode="std"):
|
||||
"""
|
||||
Return a copy of the merged tensors.
|
||||
|
||||
@@ -223,22 +228,23 @@ class LabelTensor(torch.Tensor):
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
if set(self.labels).intersection(lt.labels):
|
||||
raise RuntimeError('The tensors to merge have common labels')
|
||||
raise RuntimeError("The tensors to merge have common labels")
|
||||
|
||||
new_labels = self.labels + lt.labels
|
||||
if mode == 'std':
|
||||
if mode == "std":
|
||||
new_tensor = torch.cat((self, lt), dim=1)
|
||||
elif mode == 'first':
|
||||
elif mode == "first":
|
||||
raise NotImplementedError
|
||||
elif mode == 'cross':
|
||||
elif mode == "cross":
|
||||
tensor1 = self
|
||||
tensor2 = lt
|
||||
n1 = tensor1.shape[0]
|
||||
n2 = tensor2.shape[0]
|
||||
|
||||
tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels)
|
||||
tensor2 = LabelTensor(tensor2.repeat_interleave(n1, dim=0),
|
||||
labels=tensor2.labels)
|
||||
tensor2 = LabelTensor(
|
||||
tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels
|
||||
)
|
||||
new_tensor = torch.cat((tensor1, tensor2), dim=1)
|
||||
|
||||
new_tensor = new_tensor.as_subclass(LabelTensor)
|
||||
@@ -250,34 +256,37 @@ class LabelTensor(torch.Tensor):
|
||||
Return a copy of the selected tensor.
|
||||
"""
|
||||
|
||||
if isinstance(index, str) or (isinstance(index, (tuple, list))and all(isinstance(a, str) for a in index)):
|
||||
if isinstance(index, str) or (
|
||||
isinstance(index, (tuple, list))
|
||||
and all(isinstance(a, str) for a in index)
|
||||
):
|
||||
return self.extract(index)
|
||||
|
||||
selected_lt = super(Tensor, self).__getitem__(index)
|
||||
|
||||
|
||||
try:
|
||||
len_index = len(index)
|
||||
except TypeError:
|
||||
len_index = 1
|
||||
|
||||
|
||||
if isinstance(index, int) or len_index == 1:
|
||||
if selected_lt.ndim == 1:
|
||||
selected_lt = selected_lt.reshape(1, -1)
|
||||
if hasattr(self, 'labels'):
|
||||
if hasattr(self, "labels"):
|
||||
selected_lt.labels = self.labels
|
||||
elif len_index == 2:
|
||||
if selected_lt.ndim == 1:
|
||||
selected_lt = selected_lt.reshape(-1, 1)
|
||||
if hasattr(self, 'labels'):
|
||||
if hasattr(self, "labels"):
|
||||
if isinstance(index[1], list):
|
||||
selected_lt.labels = [self.labels[i] for i in index[1]]
|
||||
else:
|
||||
selected_lt.labels = self.labels[index[1]]
|
||||
else:
|
||||
selected_lt.labels = self.labels
|
||||
|
||||
|
||||
return selected_lt
|
||||
|
||||
|
||||
@property
|
||||
def tensor(self):
|
||||
return self.as_subclass(Tensor)
|
||||
@@ -286,9 +295,9 @@ class LabelTensor(torch.Tensor):
|
||||
return super().__len__()
|
||||
|
||||
def __str__(self):
|
||||
if hasattr(self, 'labels'):
|
||||
s = f'labels({str(self.labels)})\n'
|
||||
if hasattr(self, "labels"):
|
||||
s = f"labels({str(self.labels)})\n"
|
||||
else:
|
||||
s = 'no labels\n'
|
||||
s = "no labels\n"
|
||||
s += super().__str__()
|
||||
return s
|
||||
return s
|
||||
|
||||
Reference in New Issue
Block a user