diff --git a/pina/label_tensor.py b/pina/label_tensor.py index 3669272..7646dd8 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -35,14 +35,15 @@ class LabelTensor(torch.Tensor): {1: {"name": "space"['a', 'b', 'c']) """ - from .utils import check_consistency + self.labels = None if isinstance(labels, dict): - # check_consistency(labels, dict) self.update_labels(labels) elif isinstance(labels, list): self.init_labels_from_list(labels) elif isinstance(labels, str): labels = [labels] + self.init_labels_from_list(labels) + else: raise ValueError(f"labels must be list, dict or string.") def extract(self, label_to_extract): @@ -184,13 +185,13 @@ class LabelTensor(torch.Tensor): return out def update_labels(self, labels): - ''' + """ Update the internal label representation according to the values passed as input. :param labels: The label(s) to update. :type labels: dict :raises ValueError: dof list contain duplicates or number of dof does not match with tensor shape - ''' + """ self.labels = { idx_: { 'dof': range(self.tensor.shape[idx_]), @@ -206,11 +207,11 @@ class LabelTensor(torch.Tensor): self.labels.update(labels) def init_labels_from_list(self, labels): - ''' + """ Given a list of dof, this method update the internal label representation :param labels: The label(s) to update. :type labels: list - ''' + """ last_dim_labels = {self.tensor.ndim - 1: {'dof': labels, 'name': self.tensor.ndim - 1}} self.update_labels(last_dim_labels) \ No newline at end of file