Formatting
* Adding black as dev dependency * Formatting pina code * Formatting tests
This commit is contained in:
committed by
Nicola Demo
parent
4c4482b155
commit
42ab1a666b
@@ -1,4 +1,5 @@
|
||||
""" Module for LabelTensor """
|
||||
"""Module for LabelTensor"""
|
||||
|
||||
from copy import copy, deepcopy
|
||||
import torch
|
||||
from torch import Tensor
|
||||
@@ -43,7 +44,7 @@ class LabelTensor(torch.Tensor):
|
||||
:rtype: list
|
||||
"""
|
||||
if self.ndim - 1 in self._labels.keys():
|
||||
return self._labels[self.ndim - 1]['dof']
|
||||
return self._labels[self.ndim - 1]["dof"]
|
||||
|
||||
@property
|
||||
def full_labels(self):
|
||||
@@ -58,7 +59,7 @@ class LabelTensor(torch.Tensor):
|
||||
if i in self._labels.keys():
|
||||
to_return_dict[i] = self._labels[i]
|
||||
else:
|
||||
to_return_dict[i] = {'dof': range(shape_tensor[i]), 'name': i}
|
||||
to_return_dict[i] = {"dof": range(shape_tensor[i]), "name": i}
|
||||
return to_return_dict
|
||||
|
||||
@property
|
||||
@@ -72,13 +73,13 @@ class LabelTensor(torch.Tensor):
|
||||
|
||||
@labels.setter
|
||||
def labels(self, labels):
|
||||
""""
|
||||
""" "
|
||||
Set properly the parameter _labels
|
||||
|
||||
:param labels: Labels to assign to the class variable _labels.
|
||||
:type: labels: str | list(str) | dict
|
||||
"""
|
||||
if not hasattr(self, '_labels'):
|
||||
if not hasattr(self, "_labels"):
|
||||
self._labels = {}
|
||||
if isinstance(labels, dict):
|
||||
self._init_labels_from_dict(labels)
|
||||
@@ -109,27 +110,34 @@ class LabelTensor(torch.Tensor):
|
||||
if len(dof_list) != dim_size:
|
||||
raise ValueError(
|
||||
f"Number of dof ({len(dof_list)}) does not match "
|
||||
f"tensor shape ({dim_size})")
|
||||
f"tensor shape ({dim_size})"
|
||||
)
|
||||
|
||||
for dim, label in labels.items():
|
||||
if isinstance(label, dict):
|
||||
if 'name' not in label:
|
||||
label['name'] = dim
|
||||
if 'dof' not in label:
|
||||
label['dof'] = range(tensor_shape[dim])
|
||||
if 'dof' in label and 'name' in label:
|
||||
dof = label['dof']
|
||||
if "name" not in label:
|
||||
label["name"] = dim
|
||||
if "dof" not in label:
|
||||
label["dof"] = range(tensor_shape[dim])
|
||||
if "dof" in label and "name" in label:
|
||||
dof = label["dof"]
|
||||
dof_list = dof if isinstance(dof, (list, range)) else [dof]
|
||||
if not isinstance(dof_list, (list, range)):
|
||||
raise ValueError(f"'dof' should be a list or range, not"
|
||||
f" {type(dof_list)}")
|
||||
raise ValueError(
|
||||
f"'dof' should be a list or range, not"
|
||||
f" {type(dof_list)}"
|
||||
)
|
||||
validate_dof(dof_list, tensor_shape[dim])
|
||||
else:
|
||||
raise ValueError("Labels dictionary must contain either "
|
||||
" both 'name' and 'dof' keys")
|
||||
raise ValueError(
|
||||
"Labels dictionary must contain either "
|
||||
" both 'name' and 'dof' keys"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid label format for {dim}: Expected "
|
||||
f"list or dictionary, got {type(label)}")
|
||||
raise ValueError(
|
||||
f"Invalid label format for {dim}: Expected "
|
||||
f"list or dictionary, got {type(label)}"
|
||||
)
|
||||
|
||||
# Assign validated label data to internal labels
|
||||
self._labels[dim] = label
|
||||
@@ -144,10 +152,7 @@ class LabelTensor(torch.Tensor):
|
||||
"""
|
||||
# Create a dict with labels
|
||||
last_dim_labels = {
|
||||
self.ndim - 1: {
|
||||
'dof': labels,
|
||||
'name': self.ndim - 1
|
||||
}
|
||||
self.ndim - 1: {"dof": labels, "name": self.ndim - 1}
|
||||
}
|
||||
self._init_labels_from_dict(last_dim_labels)
|
||||
|
||||
@@ -165,9 +170,14 @@ class LabelTensor(torch.Tensor):
|
||||
def get_label_indices(dim_labels, labels_te):
|
||||
if isinstance(labels_te, (int, str)):
|
||||
labels_te = [labels_te]
|
||||
return [dim_labels.index(label) for label in labels_te] if len(
|
||||
labels_te) > 1 else slice(dim_labels.index(labels_te[0]),
|
||||
dim_labels.index(labels_te[0]) + 1)
|
||||
return (
|
||||
[dim_labels.index(label) for label in labels_te]
|
||||
if len(labels_te) > 1
|
||||
else slice(
|
||||
dim_labels.index(labels_te[0]),
|
||||
dim_labels.index(labels_te[0]) + 1,
|
||||
)
|
||||
)
|
||||
|
||||
# Ensure labels_to_extract is a list or dict
|
||||
if isinstance(labels_to_extract, (str, int)):
|
||||
@@ -176,37 +186,39 @@ class LabelTensor(torch.Tensor):
|
||||
labels = copy(self._labels)
|
||||
|
||||
# Get the dimension names and the respective dimension index
|
||||
dim_names = {labels[dim]['name']: dim for dim in labels.keys()}
|
||||
dim_names = {labels[dim]["name"]: dim for dim in labels.keys()}
|
||||
ndim = super().ndim
|
||||
tensor = self.tensor.as_subclass(torch.Tensor)
|
||||
|
||||
# Convert list/tuple to a dict for the last dimension if applicable
|
||||
if isinstance(labels_to_extract, (list, tuple)):
|
||||
last_dim = ndim - 1
|
||||
dim_name = labels[last_dim]['name']
|
||||
dim_name = labels[last_dim]["name"]
|
||||
labels_to_extract = {dim_name: list(labels_to_extract)}
|
||||
|
||||
# Validate the labels_to_extract type
|
||||
if not isinstance(labels_to_extract, dict):
|
||||
raise ValueError(
|
||||
"labels_to_extract must be a string, list, or dictionary.")
|
||||
"labels_to_extract must be a string, list, or dictionary."
|
||||
)
|
||||
|
||||
# Perform the extraction for each specified dimension
|
||||
for dim_name, labels_te in labels_to_extract.items():
|
||||
if dim_name not in dim_names:
|
||||
raise ValueError(
|
||||
f"Cannot extract labels for dimension '{dim_name}' as it is"
|
||||
f" not present in the original labels.")
|
||||
f" not present in the original labels."
|
||||
)
|
||||
|
||||
idx_dim = dim_names[dim_name]
|
||||
dim_labels = labels[idx_dim]['dof']
|
||||
dim_labels = labels[idx_dim]["dof"]
|
||||
indices = get_label_indices(dim_labels, labels_te)
|
||||
|
||||
extractor = [slice(None)] * ndim
|
||||
extractor[idx_dim] = indices
|
||||
tensor = tensor[tuple(extractor)]
|
||||
|
||||
labels[idx_dim] = {'dof': labels_te, 'name': dim_name}
|
||||
labels[idx_dim] = {"dof": labels_te, "name": dim_name}
|
||||
|
||||
return LabelTensor(tensor, labels)
|
||||
|
||||
@@ -214,10 +226,10 @@ class LabelTensor(torch.Tensor):
|
||||
"""
|
||||
returns a string with the representation of the class
|
||||
"""
|
||||
s = ''
|
||||
s = ""
|
||||
for key, value in self._labels.items():
|
||||
s += f"{key}: {value}\n"
|
||||
s += '\n'
|
||||
s += "\n"
|
||||
s += self.tensor.__str__()
|
||||
return s
|
||||
|
||||
@@ -249,11 +261,14 @@ class LabelTensor(torch.Tensor):
|
||||
# concatenation dimension
|
||||
for key in tensors_labels[0].keys():
|
||||
if key != dim:
|
||||
if any(tensors_labels[i][key] != tensors_labels[0][key]
|
||||
for i in range(len(tensors_labels))):
|
||||
if any(
|
||||
tensors_labels[i][key] != tensors_labels[0][key]
|
||||
for i in range(len(tensors_labels))
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"Tensors must have the same labels along all "
|
||||
f"dimensions except {dim}.")
|
||||
f"dimensions except {dim}."
|
||||
)
|
||||
|
||||
# Copy and update the 'dof' for the concatenation dimension
|
||||
cat_labels = {k: copy(v) for k, v in tensors_labels[0].items()}
|
||||
@@ -261,9 +276,8 @@ class LabelTensor(torch.Tensor):
|
||||
# Update labels if the concatenation dimension has labels
|
||||
if dim in tensors[0].stored_labels:
|
||||
if dim in cat_labels:
|
||||
cat_dofs = [label[dim]['dof'] for label in
|
||||
tensors_labels]
|
||||
cat_labels[dim]['dof'] = sum(cat_dofs, [])
|
||||
cat_dofs = [label[dim]["dof"] for label in tensors_labels]
|
||||
cat_labels[dim]["dof"] = sum(cat_dofs, [])
|
||||
else:
|
||||
cat_labels = tensors[0].stored_labels
|
||||
|
||||
@@ -330,26 +344,30 @@ class LabelTensor(torch.Tensor):
|
||||
:return: A copy of the tensor.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
out = LabelTensor(super().clone(*args, **kwargs),
|
||||
deepcopy(self._labels))
|
||||
out = LabelTensor(
|
||||
super().clone(*args, **kwargs), deepcopy(self._labels)
|
||||
)
|
||||
return out
|
||||
|
||||
def append(self, tensor, mode='std'):
|
||||
if mode == 'std':
|
||||
def append(self, tensor, mode="std"):
|
||||
if mode == "std":
|
||||
# Call cat on last dimension
|
||||
new_label_tensor = LabelTensor.cat([self, tensor],
|
||||
dim=self.ndim - 1)
|
||||
elif mode == 'cross':
|
||||
new_label_tensor = LabelTensor.cat(
|
||||
[self, tensor], dim=self.ndim - 1
|
||||
)
|
||||
elif mode == "cross":
|
||||
# Crete tensor and call cat on last dimension
|
||||
tensor1 = self
|
||||
tensor2 = tensor
|
||||
n1 = tensor1.shape[0]
|
||||
n2 = tensor2.shape[0]
|
||||
tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels)
|
||||
tensor2 = LabelTensor(tensor2.repeat_interleave(n1, dim=0),
|
||||
labels=tensor2.labels)
|
||||
new_label_tensor = LabelTensor.cat([tensor1, tensor2],
|
||||
dim=self.ndim - 1)
|
||||
tensor2 = LabelTensor(
|
||||
tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels
|
||||
)
|
||||
new_label_tensor = LabelTensor.cat(
|
||||
[tensor1, tensor2], dim=self.ndim - 1
|
||||
)
|
||||
else:
|
||||
raise ValueError('mode must be either "std" or "cross"')
|
||||
return new_label_tensor
|
||||
@@ -368,8 +386,9 @@ class LabelTensor(torch.Tensor):
|
||||
return LabelTensor.cat(label_tensors, dim=0)
|
||||
|
||||
# This method is used to update labels
|
||||
def _update_single_label(self, old_labels, to_update_labels, index, dim,
|
||||
to_update_dim):
|
||||
def _update_single_label(
|
||||
self, old_labels, to_update_labels, index, dim, to_update_dim
|
||||
):
|
||||
"""
|
||||
Update the labels of the tensor by selecting only the labels
|
||||
:param old_labels: labels from which retrieve data
|
||||
@@ -378,24 +397,29 @@ class LabelTensor(torch.Tensor):
|
||||
:param dim: label index
|
||||
:return:
|
||||
"""
|
||||
old_dof = old_labels[to_update_dim]['dof']
|
||||
label_name = old_labels[dim]['name']
|
||||
old_dof = old_labels[to_update_dim]["dof"]
|
||||
label_name = old_labels[dim]["name"]
|
||||
# Handle slicing
|
||||
if isinstance(index, slice):
|
||||
to_update_labels[dim] = {'dof': old_dof[index], 'name': label_name}
|
||||
to_update_labels[dim] = {"dof": old_dof[index], "name": label_name}
|
||||
# Handle single integer index
|
||||
elif isinstance(index, int):
|
||||
to_update_labels[dim] = {'dof': [old_dof[index]],
|
||||
'name': label_name}
|
||||
to_update_labels[dim] = {
|
||||
"dof": [old_dof[index]],
|
||||
"name": label_name,
|
||||
}
|
||||
# Handle lists or tensors
|
||||
elif isinstance(index, (list, torch.Tensor)):
|
||||
# Handle list of bools
|
||||
if isinstance(index, torch.Tensor) and index.dtype == torch.bool:
|
||||
index = index.nonzero().squeeze()
|
||||
to_update_labels[dim] = {
|
||||
'dof': [old_dof[i] for i in index] if isinstance(old_dof,
|
||||
list) else index,
|
||||
'name': label_name
|
||||
"dof": (
|
||||
[old_dof[i] for i in index]
|
||||
if isinstance(old_dof, list)
|
||||
else index
|
||||
),
|
||||
"name": label_name,
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
@@ -404,7 +428,7 @@ class LabelTensor(torch.Tensor):
|
||||
)
|
||||
|
||||
def __getitem__(self, index):
|
||||
""""
|
||||
""" "
|
||||
Override the __getitem__ method to handle the labels of the tensor.
|
||||
Perform the __getitem__ operation on the tensor and update the labels.
|
||||
|
||||
@@ -416,8 +440,10 @@ class LabelTensor(torch.Tensor):
|
||||
:raises IndexError: If an invalid index is accessed in the tensor.
|
||||
"""
|
||||
# Handle string index
|
||||
if isinstance(index, str) or (isinstance(index, (tuple, list)) and all(
|
||||
isinstance(i, str) for i in index)):
|
||||
if isinstance(index, str) or (
|
||||
isinstance(index, (tuple, list))
|
||||
and all(isinstance(i, str) for i in index)
|
||||
):
|
||||
return self.extract(index)
|
||||
|
||||
# Retrieve selected tensor and labels
|
||||
@@ -436,8 +462,9 @@ class LabelTensor(torch.Tensor):
|
||||
if isinstance(idx, int):
|
||||
selected_tensor = selected_tensor.unsqueeze(dim)
|
||||
if idx != slice(None):
|
||||
self._update_single_label(original_labels, updated_labels,
|
||||
idx, dim, offset)
|
||||
self._update_single_label(
|
||||
original_labels, updated_labels, idx, dim, offset
|
||||
)
|
||||
else:
|
||||
# Adjust label keys if dimension is reduced (case of integer
|
||||
# index on a non-labeled dimension)
|
||||
@@ -472,7 +499,7 @@ class LabelTensor(torch.Tensor):
|
||||
dim = self.ndim - 1
|
||||
if self.shape[dim] == 1:
|
||||
return self
|
||||
labels = self.stored_labels[dim]['dof']
|
||||
labels = self.stored_labels[dim]["dof"]
|
||||
sorted_index = arg_sort(labels)
|
||||
# Define an indexer to sort the tensor along the specified dimension
|
||||
indexer = [slice(None)] * self.ndim
|
||||
@@ -509,10 +536,7 @@ class LabelTensor(torch.Tensor):
|
||||
# Update lables
|
||||
labels = self._labels
|
||||
keys_list = list(*dims)
|
||||
labels = {
|
||||
keys_list.index(k): labels[k]
|
||||
for k in labels.keys()
|
||||
}
|
||||
labels = {keys_list.index(k): labels[k] for k in labels.keys()}
|
||||
|
||||
# Assign labels to the new tensor
|
||||
tensor._labels = labels
|
||||
@@ -550,7 +574,7 @@ class LabelTensor(torch.Tensor):
|
||||
"""
|
||||
|
||||
if not tensors:
|
||||
raise ValueError('The tensors list must not be empty.')
|
||||
raise ValueError("The tensors list must not be empty.")
|
||||
|
||||
if len(tensors) == 1:
|
||||
return tensors[0]
|
||||
@@ -565,13 +589,13 @@ class LabelTensor(torch.Tensor):
|
||||
last_dim_labels.append(tensor.labels)
|
||||
|
||||
# Construct last dimension labels
|
||||
last_dim_labels = ['+'.join(items) for items in zip(*last_dim_labels)]
|
||||
last_dim_labels = ["+".join(items) for items in zip(*last_dim_labels)]
|
||||
|
||||
# Update the labels for the resulting tensor
|
||||
labels = {k: copy(v) for k, v in tensors[0].stored_labels.items()}
|
||||
labels[tensors[0].ndim - 1] = {
|
||||
'dof': last_dim_labels,
|
||||
'name': tensors[0].name
|
||||
"dof": last_dim_labels,
|
||||
"name": tensors[0].name,
|
||||
}
|
||||
|
||||
return LabelTensor(data, labels)
|
||||
|
||||
Reference in New Issue
Block a user