From c53c3d5b8426bbe89bde1a3672b633c0ba42e002 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Sat, 28 Sep 2024 12:23:16 +0200 Subject: [PATCH] Implement definition of LabelTensor from list, implement cat method (previously stack) and re-implement extract --- pina/label_tensor.py | 469 ++++++++----------------------------- tests/test_label_tensor.py | 109 +-------- 2 files changed, 105 insertions(+), 473 deletions(-) diff --git a/pina/label_tensor.py b/pina/label_tensor.py index ab6045e..f40c36d 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -1,311 +1,8 @@ """ Module for LabelTensor """ -from copy import deepcopy import torch from torch import Tensor - -# class LabelTensor(torch.Tensor): -# """Torch tensor with a label for any column.""" - -# @staticmethod -# def __new__(cls, x, labels, *args, **kwargs): -# return super().__new__(cls, x, *args, **kwargs) - -# def __init__(self, x, labels): -# """ -# Construct a `LabelTensor` by passing a tensor and a list of column -# labels. Such labels uniquely identify the columns of the tensor, -# allowing for an easier manipulation. - -# :param torch.Tensor x: The data tensor. -# :param labels: The labels of the columns. -# :type labels: str | list(str) | tuple(str) - -# :Example: -# >>> from pina import LabelTensor -# >>> tensor = LabelTensor(torch.rand((2000, 3)), ['a', 'b', 'c']) -# >>> tensor -# tensor([[6.7116e-02, 4.8892e-01, 8.9452e-01], -# [9.2392e-01, 8.2065e-01, 4.1986e-04], -# [8.9266e-01, 5.5446e-01, 6.3500e-01], -# ..., -# [5.8194e-01, 9.4268e-01, 4.1841e-01], -# [1.0246e-01, 9.5179e-01, 3.7043e-02], -# [9.6150e-01, 8.0656e-01, 8.3824e-01]]) -# >>> tensor.extract('a') -# tensor([[0.0671], -# [0.9239], -# [0.8927], -# ..., -# [0.5819], -# [0.1025], -# [0.9615]]) -# >>> tensor['a'] -# tensor([[0.0671], -# [0.9239], -# [0.8927], -# ..., -# [0.5819], -# [0.1025], -# [0.9615]]) -# >>> tensor.extract(['a', 'b']) -# tensor([[0.0671, 0.4889], -# [0.9239, 0.8207], -# [0.8927, 0.5545], -# ..., -# [0.5819, 0.9427], -# [0.1025, 0.9518], -# [0.9615, 0.8066]]) -# >>> tensor.extract(['b', 'a']) -# tensor([[0.4889, 0.0671], -# [0.8207, 0.9239], -# [0.5545, 0.8927], -# ..., -# [0.9427, 0.5819], -# [0.9518, 0.1025], -# [0.8066, 0.9615]]) -# """ -# if x.ndim == 1: -# x = x.reshape(-1, 1) - -# if isinstance(labels, str): -# labels = [labels] - -# if len(labels) != x.shape[-1]: -# raise ValueError( -# "the tensor has not the same number of columns of " -# "the passed labels." -# ) -# self._labels = labels - -# def __deepcopy__(self, __): -# """ -# Implements deepcopy for label tensor. By default it stores the -# current labels and use the :meth:`~torch._tensor.Tensor.__deepcopy__` -# method for creating a new :class:`pina.label_tensor.LabelTensor`. - -# :param __: Placeholder parameter. -# :type __: None -# :return: The deep copy of the :class:`pina.label_tensor.LabelTensor`. -# :rtype: LabelTensor -# """ -# labels = self.labels -# copy_tensor = deepcopy(self.tensor) -# return LabelTensor(copy_tensor, labels) - -# @property -# def labels(self): -# """Property decorator for labels - -# :return: labels of self -# :rtype: list -# """ -# return self._labels - -# @labels.setter -# def labels(self, labels): -# if len(labels) != self.shape[self.ndim - 1]: # small check -# raise ValueError( -# "The tensor has not the same number of columns of " -# "the passed labels." -# ) - -# self._labels = labels # assign the label - -# @staticmethod -# def vstack(label_tensors): -# """ -# Stack tensors vertically. For more details, see -# :meth:`torch.vstack`. - -# :param list(LabelTensor) label_tensors: the tensors to stack. They need -# to have equal labels. -# :return: the stacked tensor -# :rtype: LabelTensor -# """ -# if len(label_tensors) == 0: -# return [] - -# all_labels = [label for lt in label_tensors for label in lt.labels] -# if set(all_labels) != set(label_tensors[0].labels): -# raise RuntimeError("The tensors to stack have different labels") - -# labels = label_tensors[0].labels -# tensors = [lt.extract(labels) for lt in label_tensors] -# return LabelTensor(torch.vstack(tensors), labels) - -# def clone(self, *args, **kwargs): -# """ -# Clone the LabelTensor. For more details, see -# :meth:`torch.Tensor.clone`. - -# :return: A copy of the tensor. -# :rtype: LabelTensor -# """ -# # # used before merging -# # try: -# # out = LabelTensor(super().clone(*args, **kwargs), self.labels) -# # except: -# # out = super().clone(*args, **kwargs) -# out = LabelTensor(super().clone(*args, **kwargs), self.labels) -# return out - -# def to(self, *args, **kwargs): -# """ -# Performs Tensor dtype and/or device conversion. For more details, see -# :meth:`torch.Tensor.to`. -# """ -# tmp = super().to(*args, **kwargs) -# new = self.__class__.clone(self) -# new.data = tmp.data -# return new - -# def select(self, *args, **kwargs): -# """ -# Performs Tensor selection. For more details, see :meth:`torch.Tensor.select`. -# """ -# tmp = super().select(*args, **kwargs) -# tmp._labels = self._labels -# return tmp - -# def cuda(self, *args, **kwargs): -# """ -# Send Tensor to cuda. For more details, see :meth:`torch.Tensor.cuda`. -# """ -# tmp = super().cuda(*args, **kwargs) -# new = self.__class__.clone(self) -# new.data = tmp.data -# return new - -# def cpu(self, *args, **kwargs): -# """ -# Send Tensor to cpu. For more details, see :meth:`torch.Tensor.cpu`. -# """ -# tmp = super().cpu(*args, **kwargs) -# new = self.__class__.clone(self) -# new.data = tmp.data -# return new - -# def extract(self, label_to_extract): -# """ -# Extract the subset of the original tensor by returning all the columns -# corresponding to the passed ``label_to_extract``. - -# :param label_to_extract: The label(s) to extract. -# :type label_to_extract: str | list(str) | tuple(str) -# :raises TypeError: Labels are not ``str``. -# :raises ValueError: Label to extract is not in the labels ``list``. -# """ - -# if isinstance(label_to_extract, str): -# label_to_extract = [label_to_extract] -# elif isinstance(label_to_extract, (tuple, list)): # TODO -# pass -# else: -# raise TypeError( -# "`label_to_extract` should be a str, or a str iterator" -# ) - -# indeces = [] -# for f in label_to_extract: -# try: -# indeces.append(self.labels.index(f)) -# except ValueError: -# raise ValueError(f"`{f}` not in the labels list") - -# new_data = super(Tensor, self.T).__getitem__(indeces).T -# new_labels = [self.labels[idx] for idx in indeces] - -# extracted_tensor = new_data.as_subclass(LabelTensor) -# extracted_tensor.labels = new_labels - -# return extracted_tensor - -# def detach(self): -# detached = super().detach() -# if hasattr(self, "_labels"): -# detached._labels = self._labels -# return detached - - -# def append(self, lt, mode="std"): -# """ -# Return a copy of the merged tensors. - -# :param LabelTensor lt: The tensor to merge. -# :param str mode: {'std', 'first', 'cross'} -# :return: The merged tensors. -# :rtype: LabelTensor -# """ -# if set(self.labels).intersection(lt.labels): -# raise RuntimeError("The tensors to merge have common labels") - -# new_labels = self.labels + lt.labels -# if mode == "std": -# new_tensor = torch.cat((self, lt), dim=1) -# elif mode == "first": -# raise NotImplementedError -# elif mode == "cross": -# tensor1 = self -# tensor2 = lt -# n1 = tensor1.shape[0] -# n2 = tensor2.shape[0] - -# tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels) -# tensor2 = LabelTensor( -# tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels -# ) -# new_tensor = torch.cat((tensor1, tensor2), dim=1) - -# new_tensor = new_tensor.as_subclass(LabelTensor) -# new_tensor.labels = new_labels -# return new_tensor - -# def __getitem__(self, index): -# """ -# Return a copy of the selected tensor. -# """ - -# if isinstance(index, str) or ( -# isinstance(index, (tuple, list)) -# and all(isinstance(a, str) for a in index) -# ): -# return self.extract(index) - -# selected_lt = super(Tensor, self).__getitem__(index) - -# try: -# len_index = len(index) -# except TypeError: -# len_index = 1 - -# if isinstance(index, int) or len_index == 1: -# if selected_lt.ndim == 1: -# selected_lt = selected_lt.reshape(1, -1) -# if hasattr(self, "labels"): -# selected_lt.labels = self.labels -# elif len_index == 2: -# if selected_lt.ndim == 1: -# selected_lt = selected_lt.reshape(-1, 1) -# if hasattr(self, "labels"): -# if isinstance(index[1], list): -# selected_lt.labels = [self.labels[i] for i in index[1]] -# else: -# selected_lt.labels = self.labels[index[1]] -# else: -# selected_lt.labels = self.labels - -# return selected_lt - - -# def __str__(self): -# if hasattr(self, "labels"): -# s = f"labels({str(self.labels)})\n" -# else: -# s = "no labels\n" -# s += super().__str__() -# return s def issubset(a, b): """ Check if a is a subset of b. @@ -334,21 +31,19 @@ class LabelTensor(torch.Tensor): :Example: >>> from pina import LabelTensor >>> tensor = LabelTensor( - >>> torch.rand((2000, 3)), + >>> torch.rand((2000, 3)), {1: {"name": "space"['a', 'b', 'c']) - + """ from .utils import check_consistency - check_consistency(labels, dict) - - self.labels = { - idx_: { - 'dof': range(x.shape[idx_]), - 'name': idx_ - } for idx_ in range(x.ndim) - } - self.labels.update(labels) - + if isinstance(labels, dict): + # check_consistency(labels, dict) + self.update_labels(labels) + elif isinstance(labels, list): + self.init_labels_from_list(labels) + elif isinstance(labels, str): + labels = [labels] + raise ValueError(f"labels must be list, dict or string.") def extract(self, label_to_extract): """ @@ -360,47 +55,48 @@ class LabelTensor(torch.Tensor): :raises TypeError: Labels are not ``str``. :raises ValueError: Label to extract is not in the labels ``list``. """ - if isinstance(label_to_extract, (int, str)): + from copy import deepcopy + if isinstance(label_to_extract, (str, int)): label_to_extract = [label_to_extract] if isinstance(label_to_extract, (tuple, list)): - - for k, v in self.labels.items(): - if issubset(label_to_extract, v['dof']): - break + last_dim_label = self.labels[self.tensor.ndim - 1]['dof'] + if set(label_to_extract).issubset(last_dim_label) is False: + raise ValueError('Cannot extract a dof which is not in the original LabelTensor') + idx_to_extract = [last_dim_label.index(i) for i in label_to_extract] + new_tensor = deepcopy(self.tensor) + new_tensor = new_tensor[..., idx_to_extract] + new_labels = deepcopy(self.labels) + last_dim_new_label = {self.tensor.ndim - 1: { + 'dof': label_to_extract, + 'name': self.labels[self.tensor.ndim - 1]['name'] + }} + new_labels.update(last_dim_new_label) + elif isinstance(label_to_extract, dict): + new_labels = (deepcopy(self.labels)) + new_tensor = deepcopy(self.tensor) + for k, v in label_to_extract.items(): + idx_dim = None + for kl, vl in self.labels.items(): + if vl['name'] == k: + idx_dim = kl + break + dim_labels = self.labels[idx_dim]['dof'] + if isinstance(label_to_extract[k], (int, str)): + label_to_extract[k] = [label_to_extract[k]] + if set(label_to_extract[k]).issubset(dim_labels) is False: + raise ValueError('Cannot extract a dof which is not in the original labeltensor') + idx_to_extract = [dim_labels.index(i) for i in label_to_extract[k]] + indexer = [slice(None)] * idx_dim + [idx_to_extract] + [slice(None)] * (self.tensor.ndim - idx_dim - 1) + new_tensor = new_tensor[indexer] + dim_new_label = {idx_dim: { + 'dof': label_to_extract[k], + 'name': self.labels[idx_dim]['name'] + }} + new_labels.update(dim_new_label) + else: + raise ValueError('labels_to_extract must be str or list or dict') + return LabelTensor(new_tensor, new_labels) - label_to_extract = {v['name']: label_to_extract} - - for k, v in label_to_extract.items(): - if isinstance(v, (int, str)): - label_to_extract[k] = [v] - - indeces = [] - for dim in range(self.ndim): - - boolean_idx = [True] * self.shape[dim] - - for dim_to_extract, dof_to_extract in label_to_extract.items(): - if dim_to_extract == self.labels[dim]['name']: - boolean_idx = [False] * self.shape[dim] - for label in dof_to_extract: - idx_to_keep = self.labels[dim]['dof'].index(label) - boolean_idx[idx_to_keep] = True - - boolean_idx = torch.Tensor(boolean_idx).bool() - - indeces.append(boolean_idx) - - final_shapes = [sum(idx) for idx in indeces] - grids = torch.meshgrid(*indeces) - - ii = grids[0] - for grid in grids[1:]: - ii = torch.logical_and(ii, grid) - - new_tensor = self.tensor[ii].reshape(*final_shapes) - - return LabelTensor(new_tensor, label_to_extract) - def __str__(self): """ returns a string with the representation of the class @@ -409,23 +105,45 @@ class LabelTensor(torch.Tensor): s = '' for key, value in self.labels.items(): s += f"{key}: {value}\n" - s += '\n' + s += '\n' s += super().__str__() return s - + @staticmethod - def stack(tensors): - """ + def cat(tensors, dim=0): + """ + Stack a list of tensors. For example, given a tensor `a` of shape `(n,m,dof)` and a tensor `b` of dimension `(n',m,dof)` + the resulting tensor is of shape `(n+n',m,dof)` """ if len(tensors) == 0: return [] - if len(tensors) == 1: return tensors[0] - - raise NotImplementedError - labels = [tensor.labels for tensor in tensors] - + n_dims = tensors[0].ndim + new_labels_cat_dim = [] + for i in range(n_dims): + name = tensors[0].labels[i]['name'] + if i != dim: + dof = tensors[0].labels[i]['dof'] + for tensor in tensors: + dof_to_check = tensor.labels[i]['dof'] + name_to_check = tensor.labels[i]['name'] + if dof != dof_to_check or name != name_to_check: + raise ValueError('dimensions must have the same dof and name') + else: + for tensor in tensors: + new_labels_cat_dim += tensor.labels[i]['dof'] + name_to_check = tensor.labels[i]['name'] + if name != name_to_check: + raise ValueError('dimensions must have the same dof and name') + new_tensor = torch.cat(tensors, dim=dim) + labels = tensors[0].labels + labels.pop(dim) + new_labels_cat_dim = new_labels_cat_dim if len(set(new_labels_cat_dim)) == len(new_labels_cat_dim) \ + else range(new_tensor.shape[dim]) + labels[dim] = {'dof': new_labels_cat_dim, + 'name': tensors[1].labels[dim]['name']} + return LabelTensor(new_tensor, labels) def requires_grad_(self, mode=True): lt = super().requires_grad_(mode) @@ -454,10 +172,25 @@ class LabelTensor(torch.Tensor): :return: A copy of the tensor. :rtype: LabelTensor """ - # # used before merging - # try: - # out = LabelTensor(super().clone(*args, **kwargs), self.labels) - # except: - # out = super().clone(*args, **kwargs) + out = LabelTensor(super().clone(*args, **kwargs), self.labels) - return out \ No newline at end of file + return out + + def update_labels(self, labels): + self.labels = { + idx_: { + 'dof': range(self.tensor.shape[idx_]), + 'name': idx_ + } for idx_ in range(self.tensor.ndim) + } + tensor_shape = self.tensor.shape + for k, v in labels.items(): + if len(v['dof']) != len(set(v['dof'])): + raise ValueError("dof must be unique") + if len(v['dof']) != tensor_shape[k]: + raise ValueError('Number of dof does not match with tensor dimension') + self.labels.update(labels) + + def init_labels_from_list(self, labels): + last_dim_labels = {self.tensor.ndim - 1: {'dof': labels, 'name': self.tensor.ndim - 1}} + self.update_labels(last_dim_labels) \ No newline at end of file diff --git a/tests/test_label_tensor.py b/tests/test_label_tensor.py index 5d5693d..7484a49 100644 --- a/tests/test_label_tensor.py +++ b/tests/test_label_tensor.py @@ -38,7 +38,7 @@ def test_extract_column(labels, labels_te): assert torch.all(torch.isclose(data[:, 2].reshape(-1, 1), new)) @pytest.mark.parametrize("labels", [labels_row, labels_all]) -@pytest.mark.parametrize("labels_te", [2, [2], {'samples': [2]}]) +@pytest.mark.parametrize("labels_te", [{'samples': [2]}]) def test_extract_row(labels, labels_te): tensor = LabelTensor(data, labels) new = tensor.extract(labels_te) @@ -62,7 +62,7 @@ def test_extract_2D(labels_te): def test_extract_3D(): labels = labels_all - data = torch.rand((20, 3, 4)) + data = torch.rand(20, 3, 4) labels = { 1: { "name": "space", @@ -77,6 +77,7 @@ def test_extract_3D(): 'space': ['x', 'z'], 'time': range(1, 4) } + tensor = LabelTensor(data, labels) new = tensor.extract(labels_te) assert new.ndim == tensor.ndim @@ -86,106 +87,4 @@ def test_extract_3D(): assert torch.all(torch.isclose( data[:, 0::2, 1:4].reshape(20, 2, 3), new - )) - -# def test_labels(): -# tensor = LabelTensor(data, labels) -# assert isinstance(tensor, torch.Tensor) -# assert tensor.labels == labels -# with pytest.raises(ValueError): -# tensor.labels = labels[:-1] - - -# def test_extract(): -# label_to_extract = ['a', 'c'] -# tensor = LabelTensor(data, labels) -# new = tensor.extract(label_to_extract) -# assert new.labels == label_to_extract -# assert new.shape[1] == len(label_to_extract) -# assert torch.all(torch.isclose(data[:, 0::2], new)) - - -# def test_extract_onelabel(): -# label_to_extract = ['a'] -# tensor = LabelTensor(data, labels) -# new = tensor.extract(label_to_extract) -# assert new.ndim == 2 -# assert new.labels == label_to_extract -# assert new.shape[1] == len(label_to_extract) -# assert torch.all(torch.isclose(data[:, 0].reshape(-1, 1), new)) - - -# def test_wrong_extract(): -# label_to_extract = ['a', 'cc'] -# tensor = LabelTensor(data, labels) -# with pytest.raises(ValueError): -# tensor.extract(label_to_extract) - - -# def test_extract_order(): -# label_to_extract = ['c', 'a'] -# tensor = LabelTensor(data, labels) -# new = tensor.extract(label_to_extract) -# expected = torch.cat( -# (data[:, 2].reshape(-1, 1), data[:, 0].reshape(-1, 1)), -# dim=1) -# assert new.labels == label_to_extract -# assert new.shape[1] == len(label_to_extract) -# assert torch.all(torch.isclose(expected, new)) - - -# def test_merge(): -# tensor = LabelTensor(data, labels) -# tensor_a = tensor.extract('a') -# tensor_b = tensor.extract('b') -# tensor_c = tensor.extract('c') - -# tensor_bc = tensor_b.append(tensor_c) -# assert torch.allclose(tensor_bc, tensor.extract(['b', 'c'])) - - -# def test_merge2(): -# tensor = LabelTensor(data, labels) -# tensor_b = tensor.extract('b') -# tensor_c = tensor.extract('c') - -# tensor_bc = tensor_b.append(tensor_c) -# assert torch.allclose(tensor_bc, tensor.extract(['b', 'c'])) - - -# def test_getitem(): -# tensor = LabelTensor(data, labels) -# tensor_view = tensor['a'] - -# assert tensor_view.labels == ['a'] -# assert torch.allclose(tensor_view.flatten(), data[:, 0]) - -# tensor_view = tensor['a', 'c'] - -# assert tensor_view.labels == ['a', 'c'] -# assert torch.allclose(tensor_view, data[:, 0::2]) - -# def test_getitem2(): -# tensor = LabelTensor(data, labels) -# tensor_view = tensor[:5] -# assert tensor_view.labels == labels -# assert torch.allclose(tensor_view, data[:5]) - -# idx = torch.randperm(tensor.shape[0]) -# tensor_view = tensor[idx] -# assert tensor_view.labels == labels - - -# def test_slice(): -# tensor = LabelTensor(data, labels) -# tensor_view = tensor[:5, :2] -# assert tensor_view.labels == labels[:2] -# assert torch.allclose(tensor_view, data[:5, :2]) - -# tensor_view2 = tensor[3] -# assert tensor_view2.labels == labels -# assert torch.allclose(tensor_view2, data[3]) - -# tensor_view3 = tensor[:, 2] -# assert tensor_view3.labels == labels[2] -# assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1)) + )) \ No newline at end of file