Minor update on docstring in LabelTensor class
This commit is contained in:
committed by
Nicola Demo
parent
0353ffdd0f
commit
16351f95ae
@@ -35,14 +35,15 @@ class LabelTensor(torch.Tensor):
|
||||
{1: {"name": "space"['a', 'b', 'c'])
|
||||
|
||||
"""
|
||||
from .utils import check_consistency
|
||||
self.labels = None
|
||||
if isinstance(labels, dict):
|
||||
# check_consistency(labels, dict)
|
||||
self.update_labels(labels)
|
||||
elif isinstance(labels, list):
|
||||
self.init_labels_from_list(labels)
|
||||
elif isinstance(labels, str):
|
||||
labels = [labels]
|
||||
self.init_labels_from_list(labels)
|
||||
else:
|
||||
raise ValueError(f"labels must be list, dict or string.")
|
||||
|
||||
def extract(self, label_to_extract):
|
||||
@@ -184,13 +185,13 @@ class LabelTensor(torch.Tensor):
|
||||
return out
|
||||
|
||||
def update_labels(self, labels):
|
||||
'''
|
||||
"""
|
||||
Update the internal label representation according to the values passed as input.
|
||||
|
||||
:param labels: The label(s) to update.
|
||||
:type labels: dict
|
||||
:raises ValueError: dof list contain duplicates or number of dof does not match with tensor shape
|
||||
'''
|
||||
"""
|
||||
self.labels = {
|
||||
idx_: {
|
||||
'dof': range(self.tensor.shape[idx_]),
|
||||
@@ -206,11 +207,11 @@ class LabelTensor(torch.Tensor):
|
||||
self.labels.update(labels)
|
||||
|
||||
def init_labels_from_list(self, labels):
|
||||
'''
|
||||
"""
|
||||
Given a list of dof, this method update the internal label representation
|
||||
|
||||
:param labels: The label(s) to update.
|
||||
:type labels: list
|
||||
'''
|
||||
"""
|
||||
last_dim_labels = {self.tensor.ndim - 1: {'dof': labels, 'name': self.tensor.ndim - 1}}
|
||||
self.update_labels(last_dim_labels)
|
||||
Reference in New Issue
Block a user