Minor update on docstring in LabelTensor class
This commit is contained in:
committed by
Nicola Demo
parent
0353ffdd0f
commit
16351f95ae
@@ -35,14 +35,15 @@ class LabelTensor(torch.Tensor):
|
|||||||
{1: {"name": "space"['a', 'b', 'c'])
|
{1: {"name": "space"['a', 'b', 'c'])
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from .utils import check_consistency
|
self.labels = None
|
||||||
if isinstance(labels, dict):
|
if isinstance(labels, dict):
|
||||||
# check_consistency(labels, dict)
|
|
||||||
self.update_labels(labels)
|
self.update_labels(labels)
|
||||||
elif isinstance(labels, list):
|
elif isinstance(labels, list):
|
||||||
self.init_labels_from_list(labels)
|
self.init_labels_from_list(labels)
|
||||||
elif isinstance(labels, str):
|
elif isinstance(labels, str):
|
||||||
labels = [labels]
|
labels = [labels]
|
||||||
|
self.init_labels_from_list(labels)
|
||||||
|
else:
|
||||||
raise ValueError(f"labels must be list, dict or string.")
|
raise ValueError(f"labels must be list, dict or string.")
|
||||||
|
|
||||||
def extract(self, label_to_extract):
|
def extract(self, label_to_extract):
|
||||||
@@ -184,13 +185,13 @@ class LabelTensor(torch.Tensor):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
def update_labels(self, labels):
|
def update_labels(self, labels):
|
||||||
'''
|
"""
|
||||||
Update the internal label representation according to the values passed as input.
|
Update the internal label representation according to the values passed as input.
|
||||||
|
|
||||||
:param labels: The label(s) to update.
|
:param labels: The label(s) to update.
|
||||||
:type labels: dict
|
:type labels: dict
|
||||||
:raises ValueError: dof list contain duplicates or number of dof does not match with tensor shape
|
:raises ValueError: dof list contain duplicates or number of dof does not match with tensor shape
|
||||||
'''
|
"""
|
||||||
self.labels = {
|
self.labels = {
|
||||||
idx_: {
|
idx_: {
|
||||||
'dof': range(self.tensor.shape[idx_]),
|
'dof': range(self.tensor.shape[idx_]),
|
||||||
@@ -206,11 +207,11 @@ class LabelTensor(torch.Tensor):
|
|||||||
self.labels.update(labels)
|
self.labels.update(labels)
|
||||||
|
|
||||||
def init_labels_from_list(self, labels):
|
def init_labels_from_list(self, labels):
|
||||||
'''
|
"""
|
||||||
Given a list of dof, this method update the internal label representation
|
Given a list of dof, this method update the internal label representation
|
||||||
|
|
||||||
:param labels: The label(s) to update.
|
:param labels: The label(s) to update.
|
||||||
:type labels: list
|
:type labels: list
|
||||||
'''
|
"""
|
||||||
last_dim_labels = {self.tensor.ndim - 1: {'dof': labels, 'name': self.tensor.ndim - 1}}
|
last_dim_labels = {self.tensor.ndim - 1: {'dof': labels, 'name': self.tensor.ndim - 1}}
|
||||||
self.update_labels(last_dim_labels)
|
self.update_labels(last_dim_labels)
|
||||||
Reference in New Issue
Block a user