Update docstring in LabelTensor class

This commit is contained in:
FilippoOlivo
2024-09-28 17:45:54 +02:00
committed by Nicola Demo
parent c53c3d5b84
commit 0353ffdd0f

View File

@@ -114,6 +114,13 @@ class LabelTensor(torch.Tensor):
"""
Stack a list of tensors. For example, given a tensor `a` of shape `(n,m,dof)` and a tensor `b` of dimension `(n',m,dof)`
the resulting tensor is of shape `(n+n',m,dof)`
:param tensors: tensors to concatenate
:type tensors: list(LabelTensor)
:param dim: dimensions on which you want to perform the operation (default 0)
:type dim: int
:rtype: LabelTensor
:raises ValueError: either number dof or dimensions names differ
"""
if len(tensors) == 0:
return []
@@ -177,6 +184,13 @@ class LabelTensor(torch.Tensor):
return out
def update_labels(self, labels):
'''
Update the internal label representation according to the values passed as input.
:param labels: The label(s) to update.
:type labels: dict
:raises ValueError: dof list contain duplicates or number of dof does not match with tensor shape
'''
self.labels = {
idx_: {
'dof': range(self.tensor.shape[idx_]),
@@ -192,5 +206,11 @@ class LabelTensor(torch.Tensor):
self.labels.update(labels)
def init_labels_from_list(self, labels):
'''
Given a list of dof, this method update the internal label representation
:param labels: The label(s) to update.
:type labels: list
'''
last_dim_labels = {self.tensor.ndim - 1: {'dof': labels, 'name': self.tensor.ndim - 1}}
self.update_labels(last_dim_labels)