Codacy correction

FilippoOlivo authored 2024-10-31 09:50:19 +01:00, committed by Nicola Demo
parent ea3d1924e7
commit dd43c8304c
23 changed files with 246 additions and 214 deletions


@@ -1,5 +1,5 @@
""" Module for LabelTensor """
from copy import copy
from copy import copy, deepcopy
import torch
from torch import Tensor
@@ -10,17 +10,16 @@ def issubset(a, b):
"""
if isinstance(a, list) and isinstance(b, list):
return set(a).issubset(set(b))
elif isinstance(a, range) and isinstance(b, range):
if isinstance(a, range) and isinstance(b, range):
return a.start <= b.start and a.stop >= b.stop
else:
return False
return False
class LabelTensor(torch.Tensor):
"""Torch tensor with a label for any column."""
@staticmethod
def __new__(cls, x, labels, full=True, *args, **kwargs):
def __new__(cls, x, labels, *args, **kwargs):
if isinstance(x, LabelTensor):
return x
else:
@@ -30,7 +29,7 @@ class LabelTensor(torch.Tensor):
     def tensor(self):
         return self.as_subclass(Tensor)
 
-    def __init__(self, x, labels, full=False):
+    def __init__(self, x, labels, **kwargs):
         """
         Construct a `LabelTensor` by passing a dict of the labels
@@ -42,14 +41,19 @@ class LabelTensor(torch.Tensor):
"""
self.dim_names = None
self.full = full
self.full = kwargs.get('full', True)
self.labels = labels
@classmethod
def __internal_init__(cls, x, labels, dim_names ,full=False, *args, **kwargs):
lt = cls.__new__(cls, x, labels, full, *args, **kwargs)
def __internal_init__(cls,
x,
labels,
dim_names,
*args,
**kwargs):
lt = cls.__new__(cls, x, labels, *args, **kwargs)
lt._labels = labels
lt.full = full
lt.full = kwargs.get('full', True)
lt.dim_names = dim_names
return lt
@@ -122,8 +126,12 @@ class LabelTensor(torch.Tensor):
         tensor_shape = self.shape
 
         if hasattr(self, 'full') and self.full:
-            labels = {i: labels[i] if i in labels else {'name': i} for i in
-                      labels.keys()}
+            labels = {
+                i: labels[i] if i in labels else {
+                    'name': i
+                }
+                for i in labels.keys()
+            }
         for k, v in labels.items():
             # Init labels from str
             if isinstance(v, str):
@@ -133,8 +141,8 @@ class LabelTensor(torch.Tensor):
                 # Init from dict with only name key
                 v['dof'] = range(tensor_shape[k])
             # Init from dict with both name and dof keys
-            elif isinstance(v, dict) and sorted(list(v.keys())) == ['dof',
-                                                                    'name']:
+            elif isinstance(v, dict) and sorted(list(
+                    v.keys())) == ['dof', 'name']:
                 dof_list = v['dof']
                 dof_len = len(dof_list)
                 if dof_len != len(set(dof_list)):
@@ -143,7 +151,7 @@ class LabelTensor(torch.Tensor):
                     raise ValueError(
                         'Number of dof does not match tensor shape')
             else:
-                ValueError('Illegal labels initialization')
+                raise ValueError('Illegal labels initialization')
 
             # Perform update
             self._labels[k] = v
@@ -157,7 +165,11 @@ class LabelTensor(torch.Tensor):
"""
# Create a dict with labels
last_dim_labels = {
self.ndim - 1: {'dof': labels, 'name': self.ndim - 1}}
self.ndim - 1: {
'dof': labels,
'name': self.ndim - 1
}
}
self._init_labels_from_dict(last_dim_labels)
def set_names(self):
@@ -217,9 +229,10 @@ class LabelTensor(torch.Tensor):
             v = [v] if isinstance(v, (int, str)) else v
             if not isinstance(v, range):
-                extractor[idx_dim] = [dim_labels.index(i) for i in v] if len(
-                    v) > 1 else slice(dim_labels.index(v[0]),
-                                      dim_labels.index(v[0]) + 1)
+                extractor[idx_dim] = [dim_labels.index(i)
+                                      for i in v] if len(v) > 1 else slice(
+                                          dim_labels.index(v[0]),
+                                          dim_labels.index(v[0]) + 1)
             else:
                 extractor[idx_dim] = slice(v.start, v.stop)
@@ -263,10 +276,10 @@ class LabelTensor(torch.Tensor):
         new_tensor = torch.cat(tensors, dim=dim)
 
         # Update labels
-        labels = LabelTensor.__create_labels_cat(tensors,
-                                                 dim)
+        labels = LabelTensor.__create_labels_cat(tensors, dim)
 
-        return LabelTensor.__internal_init__(new_tensor, labels, tensors[0].dim_names)
+        return LabelTensor.__internal_init__(new_tensor, labels,
+                                             tensors[0].dim_names)
 
     @staticmethod
     def __create_labels_cat(tensors, dim):
@@ -277,9 +290,10 @@ class LabelTensor(torch.Tensor):
         # check if:
         # - labels dict have same keys
         # - all labels are the same expect for dimension dim
-        if not all(all(stored_labels[i][k] == stored_labels[0][k]
-                       for i in range(len(stored_labels)))
-                   for k in stored_labels[0].keys() if k != dim):
+        if not all(
+                all(stored_labels[i][k] == stored_labels[0][k]
+                    for i in range(len(stored_labels)))
+                for k in stored_labels[0].keys() if k != dim):
             raise RuntimeError('tensors must have the same shape and dof')
 
         labels = {k: copy(v) for k, v in tensors[0].stored_labels.items()}
@@ -341,8 +355,12 @@ class LabelTensor(torch.Tensor):
         last_dim_labels = ['+'.join(items) for items in zip(*last_dim_labels)]
 
         labels = {k: copy(v) for k, v in tensors[0].stored_labels.items()}
-        labels.update({tensors[0].ndim - 1: {'dof': last_dim_labels,
-                                             'name': tensors[0].name}})
+        labels.update({
+            tensors[0].ndim - 1: {
+                'dof': last_dim_labels,
+                'name': tensors[0].name
+            }
+        })
         return LabelTensor(data, labels)
 
     def append(self, tensor, mode='std'):
@@ -384,8 +402,9 @@ class LabelTensor(torch.Tensor):
         :param index:
         :return:
         """
-        if isinstance(index, str) or (isinstance(index, (tuple, list)) and all(
-                isinstance(a, str) for a in index)):
+        if isinstance(index,
+                      str) or (isinstance(index, (tuple, list))
+                               and all(isinstance(a, str) for a in index)):
             return self.extract(index)
 
         selected_lt = super().__getitem__(index)
@@ -418,21 +437,31 @@ class LabelTensor(torch.Tensor):
         :return:
         """
         old_dof = old_labels[dim]['dof']
-        if not isinstance(index, (int, slice)) and len(index) == len(
-                old_dof) and isinstance(old_dof, range):
+        if not isinstance(
+                index,
+                (int, slice)) and len(index) == len(old_dof) and isinstance(
+                    old_dof, range):
             return
         if isinstance(index, torch.Tensor):
-            index = index.nonzero(as_tuple=True)[
-                0] if index.dtype == torch.bool else index.tolist()
+            index = index.nonzero(
+                as_tuple=True
+            )[0] if index.dtype == torch.bool else index.tolist()
         if isinstance(index, list):
-            to_update_labels.update({dim: {
-                'dof': [old_dof[i] for i in index],
-                'name': old_labels[dim]['name']}})
+            to_update_labels.update({
+                dim: {
+                    'dof': [old_dof[i] for i in index],
+                    'name': old_labels[dim]['name']
+                }
+            })
         else:
-            to_update_labels.update({dim: {'dof': old_dof[index],
-                                           'name': old_labels[dim]['name']}})
+            to_update_labels.update(
+                {dim: {
+                    'dof': old_dof[index],
+                    'name': old_labels[dim]['name']
+                }})
 
     def sort_labels(self, dim=None):
 
         def arg_sort(lst):
             return sorted(range(len(lst)), key=lambda x: lst[x])
@@ -445,7 +474,6 @@ class LabelTensor(torch.Tensor):
         return self.__getitem__(indexer)
 
     def __deepcopy__(self, memo):
-        from copy import deepcopy
         cls = self.__class__
         result = cls(deepcopy(self.tensor), deepcopy(self.stored_labels))
         return result
@@ -454,6 +482,8 @@ class LabelTensor(torch.Tensor):
         tensor = super().permute(*dims)
         stored_labels = self.stored_labels
         keys_list = list(*dims)
-        labels = {keys_list.index(k): copy(stored_labels[k]) for k in
-                  stored_labels.keys()}
+        labels = {
+            keys_list.index(k): copy(stored_labels[k])
+            for k in stored_labels.keys()
+        }
         return LabelTensor.__internal_init__(tensor, labels, self.dim_names)
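
For reference, a minimal usage sketch of the LabelTensor API touched by this diff; the import path and the sample tensor and labels below are assumptions made for illustration, not part of the commit:

import torch
from pina.label_tensor import LabelTensor  # assumed import path

points = torch.rand(10, 2)
lt = LabelTensor(points, ['x', 'y'])       # labels attached to the last dimension
x_col = lt.extract(['x'])                  # extract a column by its label
# 'full' is now passed as a keyword argument and defaults to True
lt_plain = LabelTensor(points, {1: {'name': 'space', 'dof': ['x', 'y']}}, full=False)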