🎨 Format Python code with psf/black

This commit is contained in:
ndem0
2024-02-09 11:25:00 +00:00
committed by Nicola Demo
parent 591aeeb02b
commit cbb43a5392
64 changed files with 1323 additions and 955 deletions

View File

@@ -1,4 +1,5 @@
""" Module for LabelTensor """
from typing import Any
import torch
from torch import Tensor
@@ -12,7 +13,7 @@ class LabelTensor(torch.Tensor):
return super().__new__(cls, x, *args, **kwargs)
def __init__(self, x, labels):
'''
"""
Construct a `LabelTensor` by passing a tensor and a list of column
labels. Such labels uniquely identify the columns of the tensor,
allowing for an easier manipulation.
@@ -64,7 +65,7 @@ class LabelTensor(torch.Tensor):
[0.9427, 0.5819],
[0.9518, 0.1025],
[0.8066, 0.9615]])
'''
"""
if x.ndim == 1:
x = x.reshape(-1, 1)
@@ -72,10 +73,12 @@ class LabelTensor(torch.Tensor):
labels = [labels]
if len(labels) != x.shape[-1]:
raise ValueError('the tensor has not the same number of columns of '
'the passed labels.')
raise ValueError(
"the tensor has not the same number of columns of "
"the passed labels."
)
self._labels = labels
@property
def labels(self):
"""Property decorator for labels
@@ -88,8 +91,10 @@ class LabelTensor(torch.Tensor):
@labels.setter
def labels(self, labels):
if len(labels) != self.shape[self.ndim - 1]: # small check
raise ValueError('The tensor has not the same number of columns of '
'the passed labels.')
raise ValueError(
"The tensor has not the same number of columns of "
"the passed labels."
)
self._labels = labels # assign the label
@@ -109,7 +114,7 @@ class LabelTensor(torch.Tensor):
all_labels = [label for lt in label_tensors for label in lt.labels]
if set(all_labels) != set(label_tensors[0].labels):
raise RuntimeError('The tensors to stack have different labels')
raise RuntimeError("The tensors to stack have different labels")
labels = label_tensors[0].labels
tensors = [lt.extract(labels) for lt in label_tensors]
@@ -123,7 +128,7 @@ class LabelTensor(torch.Tensor):
:return: A copy of the tensor.
:rtype: LabelTensor
"""
# # used before merging
# # used before merging
# try:
# out = LabelTensor(super().clone(*args, **kwargs), self.labels)
# except:
@@ -184,14 +189,15 @@ class LabelTensor(torch.Tensor):
pass
else:
raise TypeError(
'`label_to_extract` should be a str, or a str iterator')
"`label_to_extract` should be a str, or a str iterator"
)
indeces = []
for f in label_to_extract:
try:
indeces.append(self.labels.index(f))
except ValueError:
raise ValueError(f'`{f}` not in the labels list')
raise ValueError(f"`{f}` not in the labels list")
new_data = super(Tensor, self.T).__getitem__(indeces).T
new_labels = [self.labels[idx] for idx in indeces]
@@ -203,17 +209,16 @@ class LabelTensor(torch.Tensor):
def detach(self):
detached = super().detach()
if hasattr(self, '_labels'):
if hasattr(self, "_labels"):
detached._labels = self._labels
return detached
def requires_grad_(self, mode = True):
def requires_grad_(self, mode=True):
lt = super().requires_grad_(mode)
lt.labels = self.labels
return lt
def append(self, lt, mode='std'):
def append(self, lt, mode="std"):
"""
Return a copy of the merged tensors.
@@ -223,22 +228,23 @@ class LabelTensor(torch.Tensor):
:rtype: LabelTensor
"""
if set(self.labels).intersection(lt.labels):
raise RuntimeError('The tensors to merge have common labels')
raise RuntimeError("The tensors to merge have common labels")
new_labels = self.labels + lt.labels
if mode == 'std':
if mode == "std":
new_tensor = torch.cat((self, lt), dim=1)
elif mode == 'first':
elif mode == "first":
raise NotImplementedError
elif mode == 'cross':
elif mode == "cross":
tensor1 = self
tensor2 = lt
n1 = tensor1.shape[0]
n2 = tensor2.shape[0]
tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels)
tensor2 = LabelTensor(tensor2.repeat_interleave(n1, dim=0),
labels=tensor2.labels)
tensor2 = LabelTensor(
tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels
)
new_tensor = torch.cat((tensor1, tensor2), dim=1)
new_tensor = new_tensor.as_subclass(LabelTensor)
@@ -250,34 +256,37 @@ class LabelTensor(torch.Tensor):
Return a copy of the selected tensor.
"""
if isinstance(index, str) or (isinstance(index, (tuple, list))and all(isinstance(a, str) for a in index)):
if isinstance(index, str) or (
isinstance(index, (tuple, list))
and all(isinstance(a, str) for a in index)
):
return self.extract(index)
selected_lt = super(Tensor, self).__getitem__(index)
try:
len_index = len(index)
except TypeError:
len_index = 1
if isinstance(index, int) or len_index == 1:
if selected_lt.ndim == 1:
selected_lt = selected_lt.reshape(1, -1)
if hasattr(self, 'labels'):
if hasattr(self, "labels"):
selected_lt.labels = self.labels
elif len_index == 2:
if selected_lt.ndim == 1:
selected_lt = selected_lt.reshape(-1, 1)
if hasattr(self, 'labels'):
if hasattr(self, "labels"):
if isinstance(index[1], list):
selected_lt.labels = [self.labels[i] for i in index[1]]
else:
selected_lt.labels = self.labels[index[1]]
else:
selected_lt.labels = self.labels
return selected_lt
@property
def tensor(self):
return self.as_subclass(Tensor)
@@ -286,9 +295,9 @@ class LabelTensor(torch.Tensor):
return super().__len__()
def __str__(self):
if hasattr(self, 'labels'):
s = f'labels({str(self.labels)})\n'
if hasattr(self, "labels"):
s = f"labels({str(self.labels)})\n"
else:
s = 'no labels\n'
s = "no labels\n"
s += super().__str__()
return s
return s