Fix Codacy Warnings (#477)
--------- Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
e3790e049a
commit
4177bfbb50
@@ -17,9 +17,15 @@ class LabelTensor(torch.Tensor):
|
||||
|
||||
@property
|
||||
def tensor(self):
|
||||
"""
|
||||
Give the tensor part of the LabelTensor.
|
||||
|
||||
:return: tensor part of the LabelTensor
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return self.as_subclass(Tensor)
|
||||
|
||||
def __init__(self, x, labels, **kwargs):
|
||||
def __init__(self, x, labels):
|
||||
"""
|
||||
Construct a `LabelTensor` by passing a dict of the labels
|
||||
|
||||
@@ -43,8 +49,9 @@ class LabelTensor(torch.Tensor):
|
||||
:return: labels of self
|
||||
:rtype: list
|
||||
"""
|
||||
if self.ndim - 1 in self._labels.keys():
|
||||
if self.ndim - 1 in self._labels:
|
||||
return self._labels[self.ndim - 1]["dof"]
|
||||
return None
|
||||
|
||||
@property
|
||||
def full_labels(self):
|
||||
@@ -55,11 +62,11 @@ class LabelTensor(torch.Tensor):
|
||||
"""
|
||||
to_return_dict = {}
|
||||
shape_tensor = self.shape
|
||||
for i in range(len(shape_tensor)):
|
||||
if i in self._labels.keys():
|
||||
for i, value in enumerate(shape_tensor):
|
||||
if i in self._labels:
|
||||
to_return_dict[i] = self._labels[i]
|
||||
else:
|
||||
to_return_dict[i] = {"dof": range(shape_tensor[i]), "name": i}
|
||||
to_return_dict[i] = {"dof": range(value), "name": i}
|
||||
return to_return_dict
|
||||
|
||||
@property
|
||||
@@ -186,7 +193,7 @@ class LabelTensor(torch.Tensor):
|
||||
labels = copy(self._labels)
|
||||
|
||||
# Get the dimension names and the respective dimension index
|
||||
dim_names = {labels[dim]["name"]: dim for dim in labels.keys()}
|
||||
dim_names = {labels[dim]["name"]: dim for dim in labels}
|
||||
ndim = super().ndim
|
||||
tensor = self.tensor.as_subclass(torch.Tensor)
|
||||
|
||||
@@ -259,7 +266,7 @@ class LabelTensor(torch.Tensor):
|
||||
|
||||
# Check label consistency across tensors, excluding the
|
||||
# concatenation dimension
|
||||
for key in tensors_labels[0].keys():
|
||||
for key in tensors_labels[0]:
|
||||
if key != dim:
|
||||
if any(
|
||||
tensors_labels[i][key] != tensors_labels[0][key]
|
||||
@@ -325,6 +332,12 @@ class LabelTensor(torch.Tensor):
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
"""
|
||||
Give the dtype of the tensor.
|
||||
|
||||
:return: dtype of the tensor
|
||||
:rtype: torch.dtype
|
||||
"""
|
||||
return super().dtype
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
@@ -350,12 +363,31 @@ class LabelTensor(torch.Tensor):
|
||||
return out
|
||||
|
||||
def append(self, tensor, mode="std"):
|
||||
"""
|
||||
Appends a given tensor to the current tensor along the last dimension.
|
||||
|
||||
This method allows for two types of appending operations:
|
||||
1. **Standard append** ("std"): Concatenates the tensors along the
|
||||
last dimension.
|
||||
2. **Cross append** ("cross"): Repeats the current tensor and the new
|
||||
tensor in a cross-product manner, then concatenates them.
|
||||
|
||||
:param LabelTensor tensor: The tensor to append.
|
||||
:param mode: The append mode to use. Defaults to "std".
|
||||
:type mode: str, optional
|
||||
:return: The new tensor obtained by appending the input tensor
|
||||
(either 'std' or 'cross').
|
||||
:rtype: LabelTensor
|
||||
|
||||
:raises ValueError: If the mode is not "std" or "cross".
|
||||
"""
|
||||
if mode == "std":
|
||||
# Call cat on last dimension
|
||||
new_label_tensor = LabelTensor.cat(
|
||||
[self, tensor], dim=self.ndim - 1
|
||||
)
|
||||
elif mode == "cross":
|
||||
return new_label_tensor
|
||||
if mode == "cross":
|
||||
# Crete tensor and call cat on last dimension
|
||||
tensor1 = self
|
||||
tensor2 = tensor
|
||||
@@ -368,9 +400,8 @@ class LabelTensor(torch.Tensor):
|
||||
new_label_tensor = LabelTensor.cat(
|
||||
[tensor1, tensor2], dim=self.ndim - 1
|
||||
)
|
||||
else:
|
||||
raise ValueError('mode must be either "std" or "cross"')
|
||||
return new_label_tensor
|
||||
return new_label_tensor
|
||||
raise ValueError('mode must be either "std" or "cross"')
|
||||
|
||||
@staticmethod
|
||||
def vstack(label_tensors):
|
||||
@@ -461,7 +492,7 @@ class LabelTensor(torch.Tensor):
|
||||
# Update labels based on the index
|
||||
offset = 0
|
||||
for dim, idx in enumerate(index):
|
||||
if dim in self.stored_labels.keys():
|
||||
if dim in self.stored_labels:
|
||||
if isinstance(idx, int):
|
||||
selected_tensor = selected_tensor.unsqueeze(dim)
|
||||
if idx != slice(None):
|
||||
@@ -508,7 +539,7 @@ class LabelTensor(torch.Tensor):
|
||||
indexer = [slice(None)] * self.ndim
|
||||
# Assigned the sorted index to the specified dimension
|
||||
indexer[dim] = sorted_index
|
||||
return self.__getitem__(tuple(indexer))
|
||||
return self[tuple(indexer)]
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
"""
|
||||
@@ -539,7 +570,7 @@ class LabelTensor(torch.Tensor):
|
||||
# Update lables
|
||||
labels = self._labels
|
||||
keys_list = list(*dims)
|
||||
labels = {keys_list.index(k): labels[k] for k in labels.keys()}
|
||||
labels = {keys_list.index(k): v for k, v in labels.items()}
|
||||
|
||||
# Assign labels to the new tensor
|
||||
tensor._labels = labels
|
||||
|
||||
Reference in New Issue
Block a user