Update docstring in LabelTensor class
This commit is contained in:
committed by
Nicola Demo
parent
c53c3d5b84
commit
0353ffdd0f
@@ -114,6 +114,13 @@ class LabelTensor(torch.Tensor):
|
|||||||
"""
|
"""
|
||||||
Stack a list of tensors. For example, given a tensor `a` of shape `(n,m,dof)` and a tensor `b` of dimension `(n',m,dof)`
|
Stack a list of tensors. For example, given a tensor `a` of shape `(n,m,dof)` and a tensor `b` of dimension `(n',m,dof)`
|
||||||
the resulting tensor is of shape `(n+n',m,dof)`
|
the resulting tensor is of shape `(n+n',m,dof)`
|
||||||
|
|
||||||
|
:param tensors: tensors to concatenate
|
||||||
|
:type tensors: list(LabelTensor)
|
||||||
|
:param dim: dimensions on which you want to perform the operation (default 0)
|
||||||
|
:type dim: int
|
||||||
|
:rtype: LabelTensor
|
||||||
|
:raises ValueError: either number dof or dimensions names differ
|
||||||
"""
|
"""
|
||||||
if len(tensors) == 0:
|
if len(tensors) == 0:
|
||||||
return []
|
return []
|
||||||
@@ -177,6 +184,13 @@ class LabelTensor(torch.Tensor):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
def update_labels(self, labels):
|
def update_labels(self, labels):
|
||||||
|
'''
|
||||||
|
Update the internal label representation according to the values passed as input.
|
||||||
|
|
||||||
|
:param labels: The label(s) to update.
|
||||||
|
:type labels: dict
|
||||||
|
:raises ValueError: dof list contain duplicates or number of dof does not match with tensor shape
|
||||||
|
'''
|
||||||
self.labels = {
|
self.labels = {
|
||||||
idx_: {
|
idx_: {
|
||||||
'dof': range(self.tensor.shape[idx_]),
|
'dof': range(self.tensor.shape[idx_]),
|
||||||
@@ -192,5 +206,11 @@ class LabelTensor(torch.Tensor):
|
|||||||
self.labels.update(labels)
|
self.labels.update(labels)
|
||||||
|
|
||||||
def init_labels_from_list(self, labels):
|
def init_labels_from_list(self, labels):
|
||||||
|
'''
|
||||||
|
Given a list of dof, this method update the internal label representation
|
||||||
|
|
||||||
|
:param labels: The label(s) to update.
|
||||||
|
:type labels: list
|
||||||
|
'''
|
||||||
last_dim_labels = {self.tensor.ndim - 1: {'dof': labels, 'name': self.tensor.ndim - 1}}
|
last_dim_labels = {self.tensor.ndim - 1: {'dof': labels, 'name': self.tensor.ndim - 1}}
|
||||||
self.update_labels(last_dim_labels)
|
self.update_labels(last_dim_labels)
|
||||||
Reference in New Issue
Block a user