Formatting

* Adding black as dev dependency
* Formatting pina code
* Formatting tests
This commit is contained in:
Dario Coscia
2025-02-24 11:26:49 +01:00
committed by Nicola Demo
parent 4c4482b155
commit 42ab1a666b
77 changed files with 1170 additions and 924 deletions

View File

@@ -1,4 +1,5 @@
""" Module for LabelTensor """
"""Module for LabelTensor"""
from copy import copy, deepcopy
import torch
from torch import Tensor
@@ -43,7 +44,7 @@ class LabelTensor(torch.Tensor):
:rtype: list
"""
if self.ndim - 1 in self._labels.keys():
return self._labels[self.ndim - 1]['dof']
return self._labels[self.ndim - 1]["dof"]
@property
def full_labels(self):
@@ -58,7 +59,7 @@ class LabelTensor(torch.Tensor):
if i in self._labels.keys():
to_return_dict[i] = self._labels[i]
else:
to_return_dict[i] = {'dof': range(shape_tensor[i]), 'name': i}
to_return_dict[i] = {"dof": range(shape_tensor[i]), "name": i}
return to_return_dict
@property
@@ -72,13 +73,13 @@ class LabelTensor(torch.Tensor):
@labels.setter
def labels(self, labels):
""""
""" "
Set properly the parameter _labels
:param labels: Labels to assign to the class variable _labels.
:type: labels: str | list(str) | dict
"""
if not hasattr(self, '_labels'):
if not hasattr(self, "_labels"):
self._labels = {}
if isinstance(labels, dict):
self._init_labels_from_dict(labels)
@@ -109,27 +110,34 @@ class LabelTensor(torch.Tensor):
if len(dof_list) != dim_size:
raise ValueError(
f"Number of dof ({len(dof_list)}) does not match "
f"tensor shape ({dim_size})")
f"tensor shape ({dim_size})"
)
for dim, label in labels.items():
if isinstance(label, dict):
if 'name' not in label:
label['name'] = dim
if 'dof' not in label:
label['dof'] = range(tensor_shape[dim])
if 'dof' in label and 'name' in label:
dof = label['dof']
if "name" not in label:
label["name"] = dim
if "dof" not in label:
label["dof"] = range(tensor_shape[dim])
if "dof" in label and "name" in label:
dof = label["dof"]
dof_list = dof if isinstance(dof, (list, range)) else [dof]
if not isinstance(dof_list, (list, range)):
raise ValueError(f"'dof' should be a list or range, not"
f" {type(dof_list)}")
raise ValueError(
f"'dof' should be a list or range, not"
f" {type(dof_list)}"
)
validate_dof(dof_list, tensor_shape[dim])
else:
raise ValueError("Labels dictionary must contain either "
" both 'name' and 'dof' keys")
raise ValueError(
"Labels dictionary must contain either "
" both 'name' and 'dof' keys"
)
else:
raise ValueError(f"Invalid label format for {dim}: Expected "
f"list or dictionary, got {type(label)}")
raise ValueError(
f"Invalid label format for {dim}: Expected "
f"list or dictionary, got {type(label)}"
)
# Assign validated label data to internal labels
self._labels[dim] = label
@@ -144,10 +152,7 @@ class LabelTensor(torch.Tensor):
"""
# Create a dict with labels
last_dim_labels = {
self.ndim - 1: {
'dof': labels,
'name': self.ndim - 1
}
self.ndim - 1: {"dof": labels, "name": self.ndim - 1}
}
self._init_labels_from_dict(last_dim_labels)
@@ -165,9 +170,14 @@ class LabelTensor(torch.Tensor):
def get_label_indices(dim_labels, labels_te):
if isinstance(labels_te, (int, str)):
labels_te = [labels_te]
return [dim_labels.index(label) for label in labels_te] if len(
labels_te) > 1 else slice(dim_labels.index(labels_te[0]),
dim_labels.index(labels_te[0]) + 1)
return (
[dim_labels.index(label) for label in labels_te]
if len(labels_te) > 1
else slice(
dim_labels.index(labels_te[0]),
dim_labels.index(labels_te[0]) + 1,
)
)
# Ensure labels_to_extract is a list or dict
if isinstance(labels_to_extract, (str, int)):
@@ -176,37 +186,39 @@ class LabelTensor(torch.Tensor):
labels = copy(self._labels)
# Get the dimension names and the respective dimension index
dim_names = {labels[dim]['name']: dim for dim in labels.keys()}
dim_names = {labels[dim]["name"]: dim for dim in labels.keys()}
ndim = super().ndim
tensor = self.tensor.as_subclass(torch.Tensor)
# Convert list/tuple to a dict for the last dimension if applicable
if isinstance(labels_to_extract, (list, tuple)):
last_dim = ndim - 1
dim_name = labels[last_dim]['name']
dim_name = labels[last_dim]["name"]
labels_to_extract = {dim_name: list(labels_to_extract)}
# Validate the labels_to_extract type
if not isinstance(labels_to_extract, dict):
raise ValueError(
"labels_to_extract must be a string, list, or dictionary.")
"labels_to_extract must be a string, list, or dictionary."
)
# Perform the extraction for each specified dimension
for dim_name, labels_te in labels_to_extract.items():
if dim_name not in dim_names:
raise ValueError(
f"Cannot extract labels for dimension '{dim_name}' as it is"
f" not present in the original labels.")
f" not present in the original labels."
)
idx_dim = dim_names[dim_name]
dim_labels = labels[idx_dim]['dof']
dim_labels = labels[idx_dim]["dof"]
indices = get_label_indices(dim_labels, labels_te)
extractor = [slice(None)] * ndim
extractor[idx_dim] = indices
tensor = tensor[tuple(extractor)]
labels[idx_dim] = {'dof': labels_te, 'name': dim_name}
labels[idx_dim] = {"dof": labels_te, "name": dim_name}
return LabelTensor(tensor, labels)
@@ -214,10 +226,10 @@ class LabelTensor(torch.Tensor):
"""
returns a string with the representation of the class
"""
s = ''
s = ""
for key, value in self._labels.items():
s += f"{key}: {value}\n"
s += '\n'
s += "\n"
s += self.tensor.__str__()
return s
@@ -249,11 +261,14 @@ class LabelTensor(torch.Tensor):
# concatenation dimension
for key in tensors_labels[0].keys():
if key != dim:
if any(tensors_labels[i][key] != tensors_labels[0][key]
for i in range(len(tensors_labels))):
if any(
tensors_labels[i][key] != tensors_labels[0][key]
for i in range(len(tensors_labels))
):
raise RuntimeError(
f"Tensors must have the same labels along all "
f"dimensions except {dim}.")
f"dimensions except {dim}."
)
# Copy and update the 'dof' for the concatenation dimension
cat_labels = {k: copy(v) for k, v in tensors_labels[0].items()}
@@ -261,9 +276,8 @@ class LabelTensor(torch.Tensor):
# Update labels if the concatenation dimension has labels
if dim in tensors[0].stored_labels:
if dim in cat_labels:
cat_dofs = [label[dim]['dof'] for label in
tensors_labels]
cat_labels[dim]['dof'] = sum(cat_dofs, [])
cat_dofs = [label[dim]["dof"] for label in tensors_labels]
cat_labels[dim]["dof"] = sum(cat_dofs, [])
else:
cat_labels = tensors[0].stored_labels
@@ -330,26 +344,30 @@ class LabelTensor(torch.Tensor):
:return: A copy of the tensor.
:rtype: LabelTensor
"""
out = LabelTensor(super().clone(*args, **kwargs),
deepcopy(self._labels))
out = LabelTensor(
super().clone(*args, **kwargs), deepcopy(self._labels)
)
return out
def append(self, tensor, mode='std'):
if mode == 'std':
def append(self, tensor, mode="std"):
if mode == "std":
# Call cat on last dimension
new_label_tensor = LabelTensor.cat([self, tensor],
dim=self.ndim - 1)
elif mode == 'cross':
new_label_tensor = LabelTensor.cat(
[self, tensor], dim=self.ndim - 1
)
elif mode == "cross":
# Crete tensor and call cat on last dimension
tensor1 = self
tensor2 = tensor
n1 = tensor1.shape[0]
n2 = tensor2.shape[0]
tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels)
tensor2 = LabelTensor(tensor2.repeat_interleave(n1, dim=0),
labels=tensor2.labels)
new_label_tensor = LabelTensor.cat([tensor1, tensor2],
dim=self.ndim - 1)
tensor2 = LabelTensor(
tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels
)
new_label_tensor = LabelTensor.cat(
[tensor1, tensor2], dim=self.ndim - 1
)
else:
raise ValueError('mode must be either "std" or "cross"')
return new_label_tensor
@@ -368,8 +386,9 @@ class LabelTensor(torch.Tensor):
return LabelTensor.cat(label_tensors, dim=0)
# This method is used to update labels
def _update_single_label(self, old_labels, to_update_labels, index, dim,
to_update_dim):
def _update_single_label(
self, old_labels, to_update_labels, index, dim, to_update_dim
):
"""
Update the labels of the tensor by selecting only the labels
:param old_labels: labels from which retrieve data
@@ -378,24 +397,29 @@ class LabelTensor(torch.Tensor):
:param dim: label index
:return:
"""
old_dof = old_labels[to_update_dim]['dof']
label_name = old_labels[dim]['name']
old_dof = old_labels[to_update_dim]["dof"]
label_name = old_labels[dim]["name"]
# Handle slicing
if isinstance(index, slice):
to_update_labels[dim] = {'dof': old_dof[index], 'name': label_name}
to_update_labels[dim] = {"dof": old_dof[index], "name": label_name}
# Handle single integer index
elif isinstance(index, int):
to_update_labels[dim] = {'dof': [old_dof[index]],
'name': label_name}
to_update_labels[dim] = {
"dof": [old_dof[index]],
"name": label_name,
}
# Handle lists or tensors
elif isinstance(index, (list, torch.Tensor)):
# Handle list of bools
if isinstance(index, torch.Tensor) and index.dtype == torch.bool:
index = index.nonzero().squeeze()
to_update_labels[dim] = {
'dof': [old_dof[i] for i in index] if isinstance(old_dof,
list) else index,
'name': label_name
"dof": (
[old_dof[i] for i in index]
if isinstance(old_dof, list)
else index
),
"name": label_name,
}
else:
raise NotImplementedError(
@@ -404,7 +428,7 @@ class LabelTensor(torch.Tensor):
)
def __getitem__(self, index):
""""
""" "
Override the __getitem__ method to handle the labels of the tensor.
Perform the __getitem__ operation on the tensor and update the labels.
@@ -416,8 +440,10 @@ class LabelTensor(torch.Tensor):
:raises IndexError: If an invalid index is accessed in the tensor.
"""
# Handle string index
if isinstance(index, str) or (isinstance(index, (tuple, list)) and all(
isinstance(i, str) for i in index)):
if isinstance(index, str) or (
isinstance(index, (tuple, list))
and all(isinstance(i, str) for i in index)
):
return self.extract(index)
# Retrieve selected tensor and labels
@@ -436,8 +462,9 @@ class LabelTensor(torch.Tensor):
if isinstance(idx, int):
selected_tensor = selected_tensor.unsqueeze(dim)
if idx != slice(None):
self._update_single_label(original_labels, updated_labels,
idx, dim, offset)
self._update_single_label(
original_labels, updated_labels, idx, dim, offset
)
else:
# Adjust label keys if dimension is reduced (case of integer
# index on a non-labeled dimension)
@@ -472,7 +499,7 @@ class LabelTensor(torch.Tensor):
dim = self.ndim - 1
if self.shape[dim] == 1:
return self
labels = self.stored_labels[dim]['dof']
labels = self.stored_labels[dim]["dof"]
sorted_index = arg_sort(labels)
# Define an indexer to sort the tensor along the specified dimension
indexer = [slice(None)] * self.ndim
@@ -509,10 +536,7 @@ class LabelTensor(torch.Tensor):
# Update lables
labels = self._labels
keys_list = list(*dims)
labels = {
keys_list.index(k): labels[k]
for k in labels.keys()
}
labels = {keys_list.index(k): labels[k] for k in labels.keys()}
# Assign labels to the new tensor
tensor._labels = labels
@@ -550,7 +574,7 @@ class LabelTensor(torch.Tensor):
"""
if not tensors:
raise ValueError('The tensors list must not be empty.')
raise ValueError("The tensors list must not be empty.")
if len(tensors) == 1:
return tensors[0]
@@ -565,13 +589,13 @@ class LabelTensor(torch.Tensor):
last_dim_labels.append(tensor.labels)
# Construct last dimension labels
last_dim_labels = ['+'.join(items) for items in zip(*last_dim_labels)]
last_dim_labels = ["+".join(items) for items in zip(*last_dim_labels)]
# Update the labels for the resulting tensor
labels = {k: copy(v) for k, v in tensors[0].stored_labels.items()}
labels[tensors[0].ndim - 1] = {
'dof': last_dim_labels,
'name': tensors[0].name
"dof": last_dim_labels,
"name": tensors[0].name,
}
return LabelTensor(data, labels)