Minor update on docstring in LabelTensor class

This commit is contained in:
FilippoOlivo
2024-09-30 10:15:45 +02:00
committed by Nicola Demo
parent 0353ffdd0f
commit 16351f95ae

View File

@@ -35,14 +35,15 @@ class LabelTensor(torch.Tensor):
{1: {"name": "space"['a', 'b', 'c']) {1: {"name": "space"['a', 'b', 'c'])
""" """
from .utils import check_consistency self.labels = None
if isinstance(labels, dict): if isinstance(labels, dict):
# check_consistency(labels, dict)
self.update_labels(labels) self.update_labels(labels)
elif isinstance(labels, list): elif isinstance(labels, list):
self.init_labels_from_list(labels) self.init_labels_from_list(labels)
elif isinstance(labels, str): elif isinstance(labels, str):
labels = [labels] labels = [labels]
self.init_labels_from_list(labels)
else:
raise ValueError(f"labels must be list, dict or string.") raise ValueError(f"labels must be list, dict or string.")
def extract(self, label_to_extract): def extract(self, label_to_extract):
@@ -184,13 +185,13 @@ class LabelTensor(torch.Tensor):
return out return out
def update_labels(self, labels): def update_labels(self, labels):
''' """
Update the internal label representation according to the values passed as input. Update the internal label representation according to the values passed as input.
:param labels: The label(s) to update. :param labels: The label(s) to update.
:type labels: dict :type labels: dict
:raises ValueError: dof list contain duplicates or number of dof does not match with tensor shape :raises ValueError: dof list contain duplicates or number of dof does not match with tensor shape
''' """
self.labels = { self.labels = {
idx_: { idx_: {
'dof': range(self.tensor.shape[idx_]), 'dof': range(self.tensor.shape[idx_]),
@@ -206,11 +207,11 @@ class LabelTensor(torch.Tensor):
self.labels.update(labels) self.labels.update(labels)
def init_labels_from_list(self, labels): def init_labels_from_list(self, labels):
''' """
Given a list of dof, this method update the internal label representation Given a list of dof, this method update the internal label representation
:param labels: The label(s) to update. :param labels: The label(s) to update.
:type labels: list :type labels: list
''' """
last_dim_labels = {self.tensor.ndim - 1: {'dof': labels, 'name': self.tensor.ndim - 1}} last_dim_labels = {self.tensor.ndim - 1: {'dof': labels, 'name': self.tensor.ndim - 1}}
self.update_labels(last_dim_labels) self.update_labels(last_dim_labels)