Minor update on docstring in LabelTensor class

This commit is contained in:
FilippoOlivo
2024-09-30 10:15:45 +02:00
committed by Nicola Demo
parent 0353ffdd0f
commit 16351f95ae

View File

@@ -35,14 +35,15 @@ class LabelTensor(torch.Tensor):
{1: {"name": "space"['a', 'b', 'c'])
"""
from .utils import check_consistency
self.labels = None
if isinstance(labels, dict):
# check_consistency(labels, dict)
self.update_labels(labels)
elif isinstance(labels, list):
self.init_labels_from_list(labels)
elif isinstance(labels, str):
labels = [labels]
self.init_labels_from_list(labels)
else:
raise ValueError(f"labels must be list, dict or string.")
def extract(self, label_to_extract):
@@ -184,13 +185,13 @@ class LabelTensor(torch.Tensor):
return out
def update_labels(self, labels):
'''
"""
Update the internal label representation according to the values passed as input.
:param labels: The label(s) to update.
:type labels: dict
:raises ValueError: dof list contain duplicates or number of dof does not match with tensor shape
'''
"""
self.labels = {
idx_: {
'dof': range(self.tensor.shape[idx_]),
@@ -206,11 +207,11 @@ class LabelTensor(torch.Tensor):
self.labels.update(labels)
def init_labels_from_list(self, labels):
'''
"""
Given a list of dof, this method update the internal label representation
:param labels: The label(s) to update.
:type labels: list
'''
"""
last_dim_labels = {self.tensor.ndim - 1: {'dof': labels, 'name': self.tensor.ndim - 1}}
self.update_labels(last_dim_labels)