Codacy correction
This commit is contained in:
committed by
Nicola Demo
parent
ea3d1924e7
commit
dd43c8304c
@@ -1,5 +1,5 @@
|
||||
""" Module for LabelTensor """
|
||||
from copy import copy
|
||||
from copy import copy, deepcopy
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
@@ -10,17 +10,16 @@ def issubset(a, b):
|
||||
"""
|
||||
if isinstance(a, list) and isinstance(b, list):
|
||||
return set(a).issubset(set(b))
|
||||
elif isinstance(a, range) and isinstance(b, range):
|
||||
if isinstance(a, range) and isinstance(b, range):
|
||||
return a.start <= b.start and a.stop >= b.stop
|
||||
else:
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
class LabelTensor(torch.Tensor):
|
||||
"""Torch tensor with a label for any column."""
|
||||
|
||||
@staticmethod
|
||||
def __new__(cls, x, labels, full=True, *args, **kwargs):
|
||||
def __new__(cls, x, labels, *args, **kwargs):
|
||||
if isinstance(x, LabelTensor):
|
||||
return x
|
||||
else:
|
||||
@@ -30,7 +29,7 @@ class LabelTensor(torch.Tensor):
|
||||
def tensor(self):
|
||||
return self.as_subclass(Tensor)
|
||||
|
||||
def __init__(self, x, labels, full=False):
|
||||
def __init__(self, x, labels, **kwargs):
|
||||
"""
|
||||
Construct a `LabelTensor` by passing a dict of the labels
|
||||
|
||||
@@ -42,14 +41,19 @@ class LabelTensor(torch.Tensor):
|
||||
|
||||
"""
|
||||
self.dim_names = None
|
||||
self.full = full
|
||||
self.full = kwargs.get('full', True)
|
||||
self.labels = labels
|
||||
|
||||
@classmethod
|
||||
def __internal_init__(cls, x, labels, dim_names ,full=False, *args, **kwargs):
|
||||
lt = cls.__new__(cls, x, labels, full, *args, **kwargs)
|
||||
def __internal_init__(cls,
|
||||
x,
|
||||
labels,
|
||||
dim_names,
|
||||
*args,
|
||||
**kwargs):
|
||||
lt = cls.__new__(cls, x, labels, *args, **kwargs)
|
||||
lt._labels = labels
|
||||
lt.full = full
|
||||
lt.full = kwargs.get('full', True)
|
||||
lt.dim_names = dim_names
|
||||
return lt
|
||||
|
||||
@@ -122,8 +126,12 @@ class LabelTensor(torch.Tensor):
|
||||
tensor_shape = self.shape
|
||||
|
||||
if hasattr(self, 'full') and self.full:
|
||||
labels = {i: labels[i] if i in labels else {'name': i} for i in
|
||||
labels.keys()}
|
||||
labels = {
|
||||
i: labels[i] if i in labels else {
|
||||
'name': i
|
||||
}
|
||||
for i in labels.keys()
|
||||
}
|
||||
for k, v in labels.items():
|
||||
# Init labels from str
|
||||
if isinstance(v, str):
|
||||
@@ -133,8 +141,8 @@ class LabelTensor(torch.Tensor):
|
||||
# Init from dict with only name key
|
||||
v['dof'] = range(tensor_shape[k])
|
||||
# Init from dict with both name and dof keys
|
||||
elif isinstance(v, dict) and sorted(list(v.keys())) == ['dof',
|
||||
'name']:
|
||||
elif isinstance(v, dict) and sorted(list(
|
||||
v.keys())) == ['dof', 'name']:
|
||||
dof_list = v['dof']
|
||||
dof_len = len(dof_list)
|
||||
if dof_len != len(set(dof_list)):
|
||||
@@ -143,7 +151,7 @@ class LabelTensor(torch.Tensor):
|
||||
raise ValueError(
|
||||
'Number of dof does not match tensor shape')
|
||||
else:
|
||||
ValueError('Illegal labels initialization')
|
||||
raise ValueError('Illegal labels initialization')
|
||||
# Perform update
|
||||
self._labels[k] = v
|
||||
|
||||
@@ -157,7 +165,11 @@ class LabelTensor(torch.Tensor):
|
||||
"""
|
||||
# Create a dict with labels
|
||||
last_dim_labels = {
|
||||
self.ndim - 1: {'dof': labels, 'name': self.ndim - 1}}
|
||||
self.ndim - 1: {
|
||||
'dof': labels,
|
||||
'name': self.ndim - 1
|
||||
}
|
||||
}
|
||||
self._init_labels_from_dict(last_dim_labels)
|
||||
|
||||
def set_names(self):
|
||||
@@ -217,9 +229,10 @@ class LabelTensor(torch.Tensor):
|
||||
v = [v] if isinstance(v, (int, str)) else v
|
||||
|
||||
if not isinstance(v, range):
|
||||
extractor[idx_dim] = [dim_labels.index(i) for i in v] if len(
|
||||
v) > 1 else slice(dim_labels.index(v[0]),
|
||||
dim_labels.index(v[0]) + 1)
|
||||
extractor[idx_dim] = [dim_labels.index(i)
|
||||
for i in v] if len(v) > 1 else slice(
|
||||
dim_labels.index(v[0]),
|
||||
dim_labels.index(v[0]) + 1)
|
||||
else:
|
||||
extractor[idx_dim] = slice(v.start, v.stop)
|
||||
|
||||
@@ -263,10 +276,10 @@ class LabelTensor(torch.Tensor):
|
||||
new_tensor = torch.cat(tensors, dim=dim)
|
||||
|
||||
# Update labels
|
||||
labels = LabelTensor.__create_labels_cat(tensors,
|
||||
dim)
|
||||
labels = LabelTensor.__create_labels_cat(tensors, dim)
|
||||
|
||||
return LabelTensor.__internal_init__(new_tensor, labels, tensors[0].dim_names)
|
||||
return LabelTensor.__internal_init__(new_tensor, labels,
|
||||
tensors[0].dim_names)
|
||||
|
||||
@staticmethod
|
||||
def __create_labels_cat(tensors, dim):
|
||||
@@ -277,9 +290,10 @@ class LabelTensor(torch.Tensor):
|
||||
# check if:
|
||||
# - labels dict have same keys
|
||||
# - all labels are the same expect for dimension dim
|
||||
if not all(all(stored_labels[i][k] == stored_labels[0][k]
|
||||
for i in range(len(stored_labels)))
|
||||
for k in stored_labels[0].keys() if k != dim):
|
||||
if not all(
|
||||
all(stored_labels[i][k] == stored_labels[0][k]
|
||||
for i in range(len(stored_labels)))
|
||||
for k in stored_labels[0].keys() if k != dim):
|
||||
raise RuntimeError('tensors must have the same shape and dof')
|
||||
|
||||
labels = {k: copy(v) for k, v in tensors[0].stored_labels.items()}
|
||||
@@ -341,8 +355,12 @@ class LabelTensor(torch.Tensor):
|
||||
|
||||
last_dim_labels = ['+'.join(items) for items in zip(*last_dim_labels)]
|
||||
labels = {k: copy(v) for k, v in tensors[0].stored_labels.items()}
|
||||
labels.update({tensors[0].ndim - 1: {'dof': last_dim_labels,
|
||||
'name': tensors[0].name}})
|
||||
labels.update({
|
||||
tensors[0].ndim - 1: {
|
||||
'dof': last_dim_labels,
|
||||
'name': tensors[0].name
|
||||
}
|
||||
})
|
||||
return LabelTensor(data, labels)
|
||||
|
||||
def append(self, tensor, mode='std'):
|
||||
@@ -384,8 +402,9 @@ class LabelTensor(torch.Tensor):
|
||||
:param index:
|
||||
:return:
|
||||
"""
|
||||
if isinstance(index, str) or (isinstance(index, (tuple, list)) and all(
|
||||
isinstance(a, str) for a in index)):
|
||||
if isinstance(index,
|
||||
str) or (isinstance(index, (tuple, list))
|
||||
and all(isinstance(a, str) for a in index)):
|
||||
return self.extract(index)
|
||||
|
||||
selected_lt = super().__getitem__(index)
|
||||
@@ -418,21 +437,31 @@ class LabelTensor(torch.Tensor):
|
||||
:return:
|
||||
"""
|
||||
old_dof = old_labels[dim]['dof']
|
||||
if not isinstance(index, (int, slice)) and len(index) == len(
|
||||
old_dof) and isinstance(old_dof, range):
|
||||
if not isinstance(
|
||||
index,
|
||||
(int, slice)) and len(index) == len(old_dof) and isinstance(
|
||||
old_dof, range):
|
||||
return
|
||||
if isinstance(index, torch.Tensor):
|
||||
index = index.nonzero(as_tuple=True)[
|
||||
0] if index.dtype == torch.bool else index.tolist()
|
||||
index = index.nonzero(
|
||||
as_tuple=True
|
||||
)[0] if index.dtype == torch.bool else index.tolist()
|
||||
if isinstance(index, list):
|
||||
to_update_labels.update({dim: {
|
||||
'dof': [old_dof[i] for i in index],
|
||||
'name': old_labels[dim]['name']}})
|
||||
to_update_labels.update({
|
||||
dim: {
|
||||
'dof': [old_dof[i] for i in index],
|
||||
'name': old_labels[dim]['name']
|
||||
}
|
||||
})
|
||||
else:
|
||||
to_update_labels.update({dim: {'dof': old_dof[index],
|
||||
'name': old_labels[dim]['name']}})
|
||||
to_update_labels.update(
|
||||
{dim: {
|
||||
'dof': old_dof[index],
|
||||
'name': old_labels[dim]['name']
|
||||
}})
|
||||
|
||||
def sort_labels(self, dim=None):
|
||||
|
||||
def arg_sort(lst):
|
||||
return sorted(range(len(lst)), key=lambda x: lst[x])
|
||||
|
||||
@@ -445,7 +474,6 @@ class LabelTensor(torch.Tensor):
|
||||
return self.__getitem__(indexer)
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
from copy import deepcopy
|
||||
cls = self.__class__
|
||||
result = cls(deepcopy(self.tensor), deepcopy(self.stored_labels))
|
||||
return result
|
||||
@@ -454,6 +482,8 @@ class LabelTensor(torch.Tensor):
|
||||
tensor = super().permute(*dims)
|
||||
stored_labels = self.stored_labels
|
||||
keys_list = list(*dims)
|
||||
labels = {keys_list.index(k): copy(stored_labels[k]) for k in
|
||||
stored_labels.keys()}
|
||||
labels = {
|
||||
keys_list.index(k): copy(stored_labels[k])
|
||||
for k in stored_labels.keys()
|
||||
}
|
||||
return LabelTensor.__internal_init__(tensor, labels, self.dim_names)
|
||||
|
||||
Reference in New Issue
Block a user