From eb146ea2ea0b39f8d74216aaebd1eff5e6bb2a31 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Wed, 23 Oct 2024 14:54:31 +0200 Subject: [PATCH] Fix bug and improve __getitem__ --- pina/label_tensor.py | 174 ++++++++++++++++++++++++++++--------------- 1 file changed, 116 insertions(+), 58 deletions(-) diff --git a/pina/label_tensor.py b/pina/label_tensor.py index 65655e9..62d8795 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -22,9 +22,6 @@ class LabelTensor(torch.Tensor): def tensor(self): return self.as_subclass(Tensor) - def __len__(self) -> int: - return super().__len__() - def __init__(self, x, labels): """ Construct a `LabelTensor` by passing a dict of the labels @@ -75,7 +72,7 @@ class LabelTensor(torch.Tensor): labels = [labels] self.update_labels_from_list(labels) else: - raise ValueError(f"labels must be list, dict or string.") + raise ValueError("labels must be list, dict or string.") self.set_names() def set_names(self): @@ -98,10 +95,9 @@ class LabelTensor(torch.Tensor): label_to_extract = [label_to_extract] if isinstance(label_to_extract, (tuple, list)): return self._extract_from_list(label_to_extract) - elif isinstance(label_to_extract, dict): + if isinstance(label_to_extract, dict): return self._extract_from_dict(label_to_extract) - else: - raise ValueError('labels_to_extract must be str or list or dict') + raise ValueError('labels_to_extract must be str or list or dict') def _extract_from_list(self, labels_to_extract): # Store locally all necessary obj/variables @@ -112,7 +108,8 @@ class LabelTensor(torch.Tensor): # Verify if all the labels in labels_to_extract are in last dimension if set(labels_to_extract).issubset(last_dim_label) is False: - raise ValueError('Cannot extract a dof which is not in the original LabelTensor') + raise ValueError( + 'Cannot extract a dof which is not in the original LabelTensor') # Extract index to extract idx_to_extract = [last_dim_label.index(i) for i in labels_to_extract] @@ -142,9 +139,12 @@ class LabelTensor(torch.Tensor): if isinstance(labels_to_extract[k], (int, str)): labels_to_extract[k] = [labels_to_extract[k]] if set(labels_to_extract[k]).issubset(dim_labels) is False: - raise ValueError('Cannot extract a dof which is not in the original LabelTensor') + raise ValueError( + 'Cannot extract a dof which is not in the original ' + 'LabelTensor') idx_to_extract = [dim_labels.index(i) for i in labels_to_extract[k]] - indexer = [slice(None)] * idx_dim + [idx_to_extract] + [slice(None)] * (ndim - idx_dim - 1) + indexer = [slice(None)] * idx_dim + [idx_to_extract] + [ + slice(None)] * (ndim - idx_dim - 1) new_tensor = new_tensor[indexer] dim_new_label = {idx_dim: { 'dof': labels_to_extract[k], @@ -168,7 +168,8 @@ class LabelTensor(torch.Tensor): @staticmethod def cat(tensors, dim=0): """ - Stack a list of tensors. For example, given a tensor `a` of shape `(n,m,dof)` and a tensor `b` of dimension `(n',m,dof)` + Stack a list of tensors. For example, given a tensor `a` of shape + `(n,m,dof)` and a tensor `b` of dimension `(n',m,dof)` the resulting tensor is of shape `(n+n',m,dof)` :param tensors: tensors to concatenate @@ -182,7 +183,8 @@ class LabelTensor(torch.Tensor): return [] if len(tensors) == 1: return tensors[0] - new_labels_cat_dim = LabelTensor._check_validity_before_cat(tensors, dim) + new_labels_cat_dim = LabelTensor._check_validity_before_cat(tensors, + dim) # Perform cat on tensors new_tensor = torch.cat(tensors, dim=dim) @@ -190,7 +192,8 @@ class LabelTensor(torch.Tensor): # Update labels labels = tensors[0].full_labels labels.pop(dim) - new_labels_cat_dim = new_labels_cat_dim if len(set(new_labels_cat_dim)) == len(new_labels_cat_dim) \ + new_labels_cat_dim = new_labels_cat_dim if len( + set(new_labels_cat_dim)) == len(new_labels_cat_dim) \ else range(new_tensor.shape[dim]) labels[dim] = {'dof': new_labels_cat_dim, 'name': tensors[1].full_labels[dim]['name']} @@ -200,7 +203,8 @@ class LabelTensor(torch.Tensor): def _check_validity_before_cat(tensors, dim): n_dims = tensors[0].ndim new_labels_cat_dim = [] - # Check if names and dof of the labels are the same in all dimensions except in dim + # Check if names and dof of the labels are the same in all dimensions + # except in dim for i in range(n_dims): name = tensors[0].full_labels[i]['name'] if i != dim: @@ -209,13 +213,15 @@ class LabelTensor(torch.Tensor): dof_to_check = tensor.full_labels[i]['dof'] name_to_check = tensor.full_labels[i]['name'] if dof != dof_to_check or name != name_to_check: - raise ValueError('dimensions must have the same dof and name') + raise ValueError( + 'dimensions must have the same dof and name') else: for tensor in tensors: new_labels_cat_dim += tensor.full_labels[i]['dof'] name_to_check = tensor.full_labels[i]['name'] if name != name_to_check: - raise ValueError('Dimensions to concatenate must have the same name') + raise ValueError( + 'Dimensions to concatenate must have the same name') return new_labels_cat_dim def requires_grad_(self, mode=True): @@ -259,11 +265,13 @@ class LabelTensor(torch.Tensor): def update_labels_from_dict(self, labels): """ - Update the internal label representation according to the values passed as input. + Update the internal label representation according to the values passed + as input. :param labels: The label(s) to update. :type labels: dict - :raises ValueError: dof list contain duplicates or number of dof does not match with tensor shape + :raises ValueError: dof list contain duplicates or number of dof does + not match with tensor shape """ tensor_shape = self.tensor.shape # Check dimensionality @@ -271,19 +279,22 @@ class LabelTensor(torch.Tensor): if len(v['dof']) != len(set(v['dof'])): raise ValueError("dof must be unique") if len(v['dof']) != tensor_shape[k]: - raise ValueError('Number of dof does not match with tensor dimension') + raise ValueError( + 'Number of dof does not match with tensor dimension') # Perform update self._labels.update(labels) def update_labels_from_list(self, labels): """ - Given a list of dof, this method update the internal label representation + Given a list of dof, this method update the internal label + representation :param labels: The label(s) to update. :type labels: list """ # Create a dict with labels - last_dim_labels = {self.tensor.ndim - 1: {'dof': labels, 'name': self.tensor.ndim - 1}} + last_dim_labels = { + self.tensor.ndim - 1: {'dof': labels, 'name': self.tensor.ndim - 1}} self.update_labels_from_dict(last_dim_labels) @staticmethod @@ -302,15 +313,16 @@ class LabelTensor(torch.Tensor): break # Sum tensors data = torch.zeros(tensors[0].tensor.shape) - for i in range(len(tensors)): - data += tensors[i].tensor + for tensor in tensors: + data += tensor.tensor new_tensor = LabelTensor(data, labels) return new_tensor def append(self, tensor, mode='std'): if mode == 'std': # Call cat on last dimension - new_label_tensor = LabelTensor.cat([self, tensor], dim=self.tensor.ndim - 1) + new_label_tensor = LabelTensor.cat([self, tensor], + dim=self.tensor.ndim - 1) elif mode == 'cross': # Crete tensor and call cat on last dimension tensor1 = self @@ -318,8 +330,10 @@ class LabelTensor(torch.Tensor): n1 = tensor1.shape[0] n2 = tensor2.shape[0] tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels) - tensor2 = LabelTensor(tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels) - new_label_tensor = LabelTensor.cat([tensor1, tensor2], dim=self.tensor.ndim - 1) + tensor2 = LabelTensor(tensor2.repeat_interleave(n1, dim=0), + labels=tensor2.labels) + new_label_tensor = LabelTensor.cat([tensor1, tensor2], + dim=self.tensor.ndim - 1) else: raise ValueError('mode must be either "std" or "cross"') return new_label_tensor @@ -339,47 +353,90 @@ class LabelTensor(torch.Tensor): def __getitem__(self, index): """ - Return a copy of the selected tensor. + TODO: Complete docstring + :param index: + :return: """ - if isinstance(index, str) or (isinstance(index, (tuple, list)) and all(isinstance(a, str) for a in index)): + if isinstance(index, str) or (isinstance(index, (tuple, list)) and all( + isinstance(a, str) for a in index)): return self.extract(index) - selected_lt = super().__getitem__(index) - try: - len_index = len(index) - except TypeError: - len_index = 1 + if isinstance(index, (int, slice)): + return self._getitem_int_slice(index, selected_lt) - if isinstance(index, int) or len_index == 1: - if selected_lt.ndim == 1: - selected_lt = selected_lt.reshape(1, -1) - if hasattr(self, "labels"): - new_labels = deepcopy(self.full_labels) - new_labels.pop(0) - selected_lt.labels = new_labels - elif len(index) == self.tensor.ndim: + if len(index) == self.tensor.ndim: + return self._getitem_full_dim_indexing(index, selected_lt) + + if isinstance(index, torch.Tensor) or ( + isinstance(index, (tuple, list)) and all( + isinstance(x, int) for x in index)): + return self._getitem_permutation(index, selected_lt) + raise ValueError('Not recognized index type') + + def _getitem_int_slice(self, index, selected_lt): + """ + :param index: + :param selected_lt: + :return: + """ + if selected_lt.ndim == 1: + selected_lt = selected_lt.reshape(1, -1) + if hasattr(self, "labels"): new_labels = deepcopy(self.full_labels) - if selected_lt.ndim == 1: - selected_lt = selected_lt.reshape(-1, 1) - for j in range(selected_lt.ndim): - if hasattr(self, "labels"): - if isinstance(index[j], list): - new_labels.update({j: {'dof': [new_labels[j]['dof'][i] for i in index[1]], - 'name': new_labels[j]['name']}}) - else: - new_labels.update({j: {'dof': new_labels[j]['dof'][index[j]], - 'name': new_labels[j]['name']}}) - + to_update_dof = new_labels[0]['dof'][index] + to_update_dof = to_update_dof if isinstance(to_update_dof, ( + tuple, list, range)) else [to_update_dof] + new_labels.update( + {0: {'dof': to_update_dof, 'name': new_labels[0]['name']}} + ) selected_lt.labels = new_labels - else: - new_labels = deepcopy(self.full_labels) - new_labels.update({0: {'dof': list[index], 'name': new_labels[0]['name']}}) - selected_lt.labels = self.labels - return selected_lt + def _getitem_full_dim_indexing(self, index, selected_lt): + new_labels = {} + old_labels = self.full_labels + if selected_lt.ndim == 1: + selected_lt = selected_lt.reshape(-1, 1) + new_labels = deepcopy(old_labels) + new_labels[1].update({'dof': old_labels[1]['dof'][index[1]], + 'name': old_labels[1]['name']}) + idx = 0 + for j in range(selected_lt.ndim): + if not isinstance(index[j], int): + if hasattr(self, "labels"): + new_labels.update( + self._update_label_for_dim(old_labels, index[j], idx)) + idx += 1 + selected_lt.labels = new_labels + return selected_lt + + def _getitem_permutation(self, index, selected_lt): + + new_labels = deepcopy(self.full_labels) + new_labels.update(self._update_label_for_dim(self.full_labels, index, + 0)) + selected_lt.labels = self.labels + return selected_lt + + @staticmethod + def _update_label_for_dim(old_labels, index, dim): + """ + TODO + :param old_labels: + :param index: + :param dim: + :return: + """ + if isinstance(index, list): + return {dim: {'dof': [old_labels[dim]['dof'][i] for i in index], + 'name': old_labels[dim]['name']}} + else: + return {dim: {'dof': old_labels[dim]['dof'][index], + 'name': old_labels[dim]['name']}} + + def sort_labels(self, dim=None): def argsort(lst): return sorted(range(len(lst)), key=lambda x: lst[x]) @@ -391,5 +448,6 @@ class LabelTensor(torch.Tensor): indexer = [slice(None)] * self.tensor.ndim indexer[dim] = sorted_index new_labels = deepcopy(self.full_labels) - new_labels[dim] = {'dof': sorted(labels), 'name': new_labels[dim]['name']} + new_labels[dim] = {'dof': sorted(labels), + 'name': new_labels[dim]['name']} return LabelTensor(self.tensor[indexer], new_labels)