From 2b71e0148dfa834fe6826680a30756a35a265614 Mon Sep 17 00:00:00 2001 From: Nicola Demo Date: Tue, 18 Jun 2024 11:20:58 +0200 Subject: [PATCH] in progress --- pina/__init__.py | 2 +- pina/dataset.py | 2 +- pina/geometry/simplex.py | 2 +- pina/label_tensor.py | 640 +++++++++++++++++++++---------------- pina/model/fno.py | 2 +- pina/plotter.py | 2 +- tests/test_label_tensor.py | 248 +++++++++----- 7 files changed, 534 insertions(+), 364 deletions(-) diff --git a/pina/__init__.py b/pina/__init__.py index 730b2ea..c63440b 100644 --- a/pina/__init__.py +++ b/pina/__init__.py @@ -9,7 +9,7 @@ __all__ = [ ] from .meta import * -from .label_tensor import LabelTensor +#from .label_tensor import LabelTensor from .solvers.solver import SolverInterface from .trainer import Trainer from .plotter import Plotter diff --git a/pina/dataset.py b/pina/dataset.py index c6a8d29..5f4ba5c 100644 --- a/pina/dataset.py +++ b/pina/dataset.py @@ -1,6 +1,6 @@ from torch.utils.data import Dataset import torch -from pina import LabelTensor +from .label_tensor import LabelTensor class SamplePointDataset(Dataset): diff --git a/pina/geometry/simplex.py b/pina/geometry/simplex.py index b04ad53..15cdc16 100644 --- a/pina/geometry/simplex.py +++ b/pina/geometry/simplex.py @@ -1,7 +1,7 @@ import torch from .location import Location from pina.geometry import CartesianDomain -from pina import LabelTensor +from ..label_tensor import LabelTensor from ..utils import check_consistency diff --git a/pina/label_tensor.py b/pina/label_tensor.py index c8a41f7..7cb2f71 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -5,6 +5,319 @@ import torch from torch import Tensor + +# class LabelTensor(torch.Tensor): +# """Torch tensor with a label for any column.""" + +# @staticmethod +# def __new__(cls, x, labels, *args, **kwargs): +# return super().__new__(cls, x, *args, **kwargs) + +# def __init__(self, x, labels): +# """ +# Construct a `LabelTensor` by passing a tensor and a list of column +# labels. Such labels uniquely identify the columns of the tensor, +# allowing for an easier manipulation. + +# :param torch.Tensor x: The data tensor. +# :param labels: The labels of the columns. +# :type labels: str | list(str) | tuple(str) + +# :Example: +# >>> from pina import LabelTensor +# >>> tensor = LabelTensor(torch.rand((2000, 3)), ['a', 'b', 'c']) +# >>> tensor +# tensor([[6.7116e-02, 4.8892e-01, 8.9452e-01], +# [9.2392e-01, 8.2065e-01, 4.1986e-04], +# [8.9266e-01, 5.5446e-01, 6.3500e-01], +# ..., +# [5.8194e-01, 9.4268e-01, 4.1841e-01], +# [1.0246e-01, 9.5179e-01, 3.7043e-02], +# [9.6150e-01, 8.0656e-01, 8.3824e-01]]) +# >>> tensor.extract('a') +# tensor([[0.0671], +# [0.9239], +# [0.8927], +# ..., +# [0.5819], +# [0.1025], +# [0.9615]]) +# >>> tensor['a'] +# tensor([[0.0671], +# [0.9239], +# [0.8927], +# ..., +# [0.5819], +# [0.1025], +# [0.9615]]) +# >>> tensor.extract(['a', 'b']) +# tensor([[0.0671, 0.4889], +# [0.9239, 0.8207], +# [0.8927, 0.5545], +# ..., +# [0.5819, 0.9427], +# [0.1025, 0.9518], +# [0.9615, 0.8066]]) +# >>> tensor.extract(['b', 'a']) +# tensor([[0.4889, 0.0671], +# [0.8207, 0.9239], +# [0.5545, 0.8927], +# ..., +# [0.9427, 0.5819], +# [0.9518, 0.1025], +# [0.8066, 0.9615]]) +# """ +# if x.ndim == 1: +# x = x.reshape(-1, 1) + +# if isinstance(labels, str): +# labels = [labels] + +# if len(labels) != x.shape[-1]: +# raise ValueError( +# "the tensor has not the same number of columns of " +# "the passed labels." +# ) +# self._labels = labels + +# def __deepcopy__(self, __): +# """ +# Implements deepcopy for label tensor. By default it stores the +# current labels and use the :meth:`~torch._tensor.Tensor.__deepcopy__` +# method for creating a new :class:`pina.label_tensor.LabelTensor`. + +# :param __: Placeholder parameter. +# :type __: None +# :return: The deep copy of the :class:`pina.label_tensor.LabelTensor`. +# :rtype: LabelTensor +# """ +# labels = self.labels +# copy_tensor = deepcopy(self.tensor) +# return LabelTensor(copy_tensor, labels) + +# @property +# def labels(self): +# """Property decorator for labels + +# :return: labels of self +# :rtype: list +# """ +# return self._labels + +# @labels.setter +# def labels(self, labels): +# if len(labels) != self.shape[self.ndim - 1]: # small check +# raise ValueError( +# "The tensor has not the same number of columns of " +# "the passed labels." +# ) + +# self._labels = labels # assign the label + +# @staticmethod +# def vstack(label_tensors): +# """ +# Stack tensors vertically. For more details, see +# :meth:`torch.vstack`. + +# :param list(LabelTensor) label_tensors: the tensors to stack. They need +# to have equal labels. +# :return: the stacked tensor +# :rtype: LabelTensor +# """ +# if len(label_tensors) == 0: +# return [] + +# all_labels = [label for lt in label_tensors for label in lt.labels] +# if set(all_labels) != set(label_tensors[0].labels): +# raise RuntimeError("The tensors to stack have different labels") + +# labels = label_tensors[0].labels +# tensors = [lt.extract(labels) for lt in label_tensors] +# return LabelTensor(torch.vstack(tensors), labels) + +# def clone(self, *args, **kwargs): +# """ +# Clone the LabelTensor. For more details, see +# :meth:`torch.Tensor.clone`. + +# :return: A copy of the tensor. +# :rtype: LabelTensor +# """ +# # # used before merging +# # try: +# # out = LabelTensor(super().clone(*args, **kwargs), self.labels) +# # except: +# # out = super().clone(*args, **kwargs) +# out = LabelTensor(super().clone(*args, **kwargs), self.labels) +# return out + +# def to(self, *args, **kwargs): +# """ +# Performs Tensor dtype and/or device conversion. For more details, see +# :meth:`torch.Tensor.to`. +# """ +# tmp = super().to(*args, **kwargs) +# new = self.__class__.clone(self) +# new.data = tmp.data +# return new + +# def select(self, *args, **kwargs): +# """ +# Performs Tensor selection. For more details, see :meth:`torch.Tensor.select`. +# """ +# tmp = super().select(*args, **kwargs) +# tmp._labels = self._labels +# return tmp + +# def cuda(self, *args, **kwargs): +# """ +# Send Tensor to cuda. For more details, see :meth:`torch.Tensor.cuda`. +# """ +# tmp = super().cuda(*args, **kwargs) +# new = self.__class__.clone(self) +# new.data = tmp.data +# return new + +# def cpu(self, *args, **kwargs): +# """ +# Send Tensor to cpu. For more details, see :meth:`torch.Tensor.cpu`. +# """ +# tmp = super().cpu(*args, **kwargs) +# new = self.__class__.clone(self) +# new.data = tmp.data +# return new + +# def extract(self, label_to_extract): +# """ +# Extract the subset of the original tensor by returning all the columns +# corresponding to the passed ``label_to_extract``. + +# :param label_to_extract: The label(s) to extract. +# :type label_to_extract: str | list(str) | tuple(str) +# :raises TypeError: Labels are not ``str``. +# :raises ValueError: Label to extract is not in the labels ``list``. +# """ + +# if isinstance(label_to_extract, str): +# label_to_extract = [label_to_extract] +# elif isinstance(label_to_extract, (tuple, list)): # TODO +# pass +# else: +# raise TypeError( +# "`label_to_extract` should be a str, or a str iterator" +# ) + +# indeces = [] +# for f in label_to_extract: +# try: +# indeces.append(self.labels.index(f)) +# except ValueError: +# raise ValueError(f"`{f}` not in the labels list") + +# new_data = super(Tensor, self.T).__getitem__(indeces).T +# new_labels = [self.labels[idx] for idx in indeces] + +# extracted_tensor = new_data.as_subclass(LabelTensor) +# extracted_tensor.labels = new_labels + +# return extracted_tensor + +# def detach(self): +# detached = super().detach() +# if hasattr(self, "_labels"): +# detached._labels = self._labels +# return detached + +# def requires_grad_(self, mode=True): +# lt = super().requires_grad_(mode) +# lt.labels = self.labels +# return lt + +# def append(self, lt, mode="std"): +# """ +# Return a copy of the merged tensors. + +# :param LabelTensor lt: The tensor to merge. +# :param str mode: {'std', 'first', 'cross'} +# :return: The merged tensors. +# :rtype: LabelTensor +# """ +# if set(self.labels).intersection(lt.labels): +# raise RuntimeError("The tensors to merge have common labels") + +# new_labels = self.labels + lt.labels +# if mode == "std": +# new_tensor = torch.cat((self, lt), dim=1) +# elif mode == "first": +# raise NotImplementedError +# elif mode == "cross": +# tensor1 = self +# tensor2 = lt +# n1 = tensor1.shape[0] +# n2 = tensor2.shape[0] + +# tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels) +# tensor2 = LabelTensor( +# tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels +# ) +# new_tensor = torch.cat((tensor1, tensor2), dim=1) + +# new_tensor = new_tensor.as_subclass(LabelTensor) +# new_tensor.labels = new_labels +# return new_tensor + +# def __getitem__(self, index): +# """ +# Return a copy of the selected tensor. +# """ + +# if isinstance(index, str) or ( +# isinstance(index, (tuple, list)) +# and all(isinstance(a, str) for a in index) +# ): +# return self.extract(index) + +# selected_lt = super(Tensor, self).__getitem__(index) + +# try: +# len_index = len(index) +# except TypeError: +# len_index = 1 + +# if isinstance(index, int) or len_index == 1: +# if selected_lt.ndim == 1: +# selected_lt = selected_lt.reshape(1, -1) +# if hasattr(self, "labels"): +# selected_lt.labels = self.labels +# elif len_index == 2: +# if selected_lt.ndim == 1: +# selected_lt = selected_lt.reshape(-1, 1) +# if hasattr(self, "labels"): +# if isinstance(index[1], list): +# selected_lt.labels = [self.labels[i] for i in index[1]] +# else: +# selected_lt.labels = self.labels[index[1]] +# else: +# selected_lt.labels = self.labels + +# return selected_lt + + +# def __str__(self): +# if hasattr(self, "labels"): +# s = f"labels({str(self.labels)})\n" +# else: +# s = "no labels\n" +# s += super().__str__() +# return s + +def issubset(a, b): + """ + Check if a is a subset of b. + """ + return set(a).issubset(set(b)) + class LabelTensor(torch.Tensor): """Torch tensor with a label for any column.""" @@ -12,180 +325,35 @@ class LabelTensor(torch.Tensor): def __new__(cls, x, labels, *args, **kwargs): return super().__new__(cls, x, *args, **kwargs) + @property + def tensor(self): + return self.as_subclass(Tensor) + + def __len__(self) -> int: + return super().__len__() + def __init__(self, x, labels): """ - Construct a `LabelTensor` by passing a tensor and a list of column - labels. Such labels uniquely identify the columns of the tensor, - allowing for an easier manipulation. - - :param torch.Tensor x: The data tensor. - :param labels: The labels of the columns. - :type labels: str | list(str) | tuple(str) + Construct a `LabelTensor` by passing a dict of the labels :Example: >>> from pina import LabelTensor - >>> tensor = LabelTensor(torch.rand((2000, 3)), ['a', 'b', 'c']) - >>> tensor - tensor([[6.7116e-02, 4.8892e-01, 8.9452e-01], - [9.2392e-01, 8.2065e-01, 4.1986e-04], - [8.9266e-01, 5.5446e-01, 6.3500e-01], - ..., - [5.8194e-01, 9.4268e-01, 4.1841e-01], - [1.0246e-01, 9.5179e-01, 3.7043e-02], - [9.6150e-01, 8.0656e-01, 8.3824e-01]]) - >>> tensor.extract('a') - tensor([[0.0671], - [0.9239], - [0.8927], - ..., - [0.5819], - [0.1025], - [0.9615]]) - >>> tensor['a'] - tensor([[0.0671], - [0.9239], - [0.8927], - ..., - [0.5819], - [0.1025], - [0.9615]]) - >>> tensor.extract(['a', 'b']) - tensor([[0.0671, 0.4889], - [0.9239, 0.8207], - [0.8927, 0.5545], - ..., - [0.5819, 0.9427], - [0.1025, 0.9518], - [0.9615, 0.8066]]) - >>> tensor.extract(['b', 'a']) - tensor([[0.4889, 0.0671], - [0.8207, 0.9239], - [0.5545, 0.8927], - ..., - [0.9427, 0.5819], - [0.9518, 0.1025], - [0.8066, 0.9615]]) + >>> tensor = LabelTensor( + >>> torch.rand((2000, 3)), + {1: {"name": "space"['a', 'b', 'c']) + """ - if x.ndim == 1: - x = x.reshape(-1, 1) + from .utils import check_consistency + check_consistency(labels, dict) - if isinstance(labels, str): - labels = [labels] + self.labels = { + idx_: { + 'dof': range(x.shape[idx_]), + 'name': idx_ + } for idx_ in range(x.ndim) + } + self.labels.update(labels) - if len(labels) != x.shape[-1]: - raise ValueError( - "the tensor has not the same number of columns of " - "the passed labels." - ) - self._labels = labels - - def __deepcopy__(self, __): - """ - Implements deepcopy for label tensor. By default it stores the - current labels and use the :meth:`~torch._tensor.Tensor.__deepcopy__` - method for creating a new :class:`pina.label_tensor.LabelTensor`. - - :param __: Placeholder parameter. - :type __: None - :return: The deep copy of the :class:`pina.label_tensor.LabelTensor`. - :rtype: LabelTensor - """ - labels = self.labels - copy_tensor = deepcopy(self.tensor) - return LabelTensor(copy_tensor, labels) - - @property - def labels(self): - """Property decorator for labels - - :return: labels of self - :rtype: list - """ - return self._labels - - @labels.setter - def labels(self, labels): - if len(labels) != self.shape[self.ndim - 1]: # small check - raise ValueError( - "The tensor has not the same number of columns of " - "the passed labels." - ) - - self._labels = labels # assign the label - - @staticmethod - def vstack(label_tensors): - """ - Stack tensors vertically. For more details, see - :meth:`torch.vstack`. - - :param list(LabelTensor) label_tensors: the tensors to stack. They need - to have equal labels. - :return: the stacked tensor - :rtype: LabelTensor - """ - if len(label_tensors) == 0: - return [] - - all_labels = [label for lt in label_tensors for label in lt.labels] - if set(all_labels) != set(label_tensors[0].labels): - raise RuntimeError("The tensors to stack have different labels") - - labels = label_tensors[0].labels - tensors = [lt.extract(labels) for lt in label_tensors] - return LabelTensor(torch.vstack(tensors), labels) - - def clone(self, *args, **kwargs): - """ - Clone the LabelTensor. For more details, see - :meth:`torch.Tensor.clone`. - - :return: A copy of the tensor. - :rtype: LabelTensor - """ - # # used before merging - # try: - # out = LabelTensor(super().clone(*args, **kwargs), self.labels) - # except: - # out = super().clone(*args, **kwargs) - out = LabelTensor(super().clone(*args, **kwargs), self.labels) - return out - - def to(self, *args, **kwargs): - """ - Performs Tensor dtype and/or device conversion. For more details, see - :meth:`torch.Tensor.to`. - """ - tmp = super().to(*args, **kwargs) - new = self.__class__.clone(self) - new.data = tmp.data - return new - - def select(self, *args, **kwargs): - """ - Performs Tensor selection. For more details, see :meth:`torch.Tensor.select`. - """ - tmp = super().select(*args, **kwargs) - tmp._labels = self._labels - return tmp - - def cuda(self, *args, **kwargs): - """ - Send Tensor to cuda. For more details, see :meth:`torch.Tensor.cuda`. - """ - tmp = super().cuda(*args, **kwargs) - new = self.__class__.clone(self) - new.data = tmp.data - return new - - def cpu(self, *args, **kwargs): - """ - Send Tensor to cpu. For more details, see :meth:`torch.Tensor.cpu`. - """ - tmp = super().cpu(*args, **kwargs) - new = self.__class__.clone(self) - new.data = tmp.data - return new def extract(self, label_to_extract): """ @@ -197,122 +365,52 @@ class LabelTensor(torch.Tensor): :raises TypeError: Labels are not ``str``. :raises ValueError: Label to extract is not in the labels ``list``. """ - - if isinstance(label_to_extract, str): + if isinstance(label_to_extract, (int, str)): label_to_extract = [label_to_extract] - elif isinstance(label_to_extract, (tuple, list)): # TODO - pass - else: - raise TypeError( - "`label_to_extract` should be a str, or a str iterator" - ) + if isinstance(label_to_extract, (tuple, list)): + + for k, v in self.labels.items(): + if issubset(label_to_extract, v['dof']): + break + + label_to_extract = {v['name']: label_to_extract} + + for k, v in label_to_extract.items(): + if isinstance(v, (int, str)): + label_to_extract[k] = [v] indeces = [] - for f in label_to_extract: - try: - indeces.append(self.labels.index(f)) - except ValueError: - raise ValueError(f"`{f}` not in the labels list") + for dim in range(self.ndim): - new_data = super(Tensor, self.T).__getitem__(indeces).T - new_labels = [self.labels[idx] for idx in indeces] + boolean_idx = [True] * self.shape[dim] - extracted_tensor = new_data.as_subclass(LabelTensor) - extracted_tensor.labels = new_labels + for dim_to_extract, dof_to_extract in label_to_extract.items(): + if dim_to_extract == self.labels[dim]['name']: + boolean_idx = [False] * self.shape[dim] + for label in dof_to_extract: + idx_to_keep = self.labels[dim]['dof'].index(label) + boolean_idx[idx_to_keep] = True - return extracted_tensor + boolean_idx = torch.Tensor(boolean_idx).bool() - def detach(self): - detached = super().detach() - if hasattr(self, "_labels"): - detached._labels = self._labels - return detached + indeces.append(boolean_idx) - def requires_grad_(self, mode=True): - lt = super().requires_grad_(mode) - lt.labels = self.labels - return lt + final_shapes = [sum(idx) for idx in indeces] + grids = torch.meshgrid(*indeces) - def append(self, lt, mode="std"): - """ - Return a copy of the merged tensors. + ii = grids[0] + for grid in grids[1:]: + ii = torch.logical_and(ii, grid) - :param LabelTensor lt: The tensor to merge. - :param str mode: {'std', 'first', 'cross'} - :return: The merged tensors. - :rtype: LabelTensor - """ - if set(self.labels).intersection(lt.labels): - raise RuntimeError("The tensors to merge have common labels") + new_tensor = self.tensor[ii].reshape(*final_shapes) - new_labels = self.labels + lt.labels - if mode == "std": - new_tensor = torch.cat((self, lt), dim=1) - elif mode == "first": - raise NotImplementedError - elif mode == "cross": - tensor1 = self - tensor2 = lt - n1 = tensor1.shape[0] - n2 = tensor2.shape[0] - - tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels) - tensor2 = LabelTensor( - tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels - ) - new_tensor = torch.cat((tensor1, tensor2), dim=1) - - new_tensor = new_tensor.as_subclass(LabelTensor) - new_tensor.labels = new_labels - return new_tensor - - def __getitem__(self, index): - """ - Return a copy of the selected tensor. - """ - - if isinstance(index, str) or ( - isinstance(index, (tuple, list)) - and all(isinstance(a, str) for a in index) - ): - return self.extract(index) - - selected_lt = super(Tensor, self).__getitem__(index) - - try: - len_index = len(index) - except TypeError: - len_index = 1 - - if isinstance(index, int) or len_index == 1: - if selected_lt.ndim == 1: - selected_lt = selected_lt.reshape(1, -1) - if hasattr(self, "labels"): - selected_lt.labels = self.labels - elif len_index == 2: - if selected_lt.ndim == 1: - selected_lt = selected_lt.reshape(-1, 1) - if hasattr(self, "labels"): - if isinstance(index[1], list): - selected_lt.labels = [self.labels[i] for i in index[1]] - else: - selected_lt.labels = self.labels[index[1]] - else: - selected_lt.labels = self.labels - - return selected_lt - - @property - def tensor(self): - return self.as_subclass(Tensor) - - def __len__(self) -> int: - return super().__len__() + return LabelTensor(new_tensor, label_to_extract) + def __str__(self): - if hasattr(self, "labels"): - s = f"labels({str(self.labels)})\n" - else: - s = "no labels\n" + s = '' + for key, value in self.labels.items(): + s += f"{key}: {value}\n" + s += '\n' s += super().__str__() - return s + return s \ No newline at end of file diff --git a/pina/model/fno.py b/pina/model/fno.py index 910b416..92ca183 100644 --- a/pina/model/fno.py +++ b/pina/model/fno.py @@ -4,7 +4,7 @@ Fourier Neural Operator Module. import torch import torch.nn as nn -from pina import LabelTensor +from ..label_tensor import LabelTensor import warnings from ..utils import check_consistency from .layers.fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D diff --git a/pina/plotter.py b/pina/plotter.py index 041ef05..eedec80 100644 --- a/pina/plotter.py +++ b/pina/plotter.py @@ -3,7 +3,7 @@ import matplotlib.pyplot as plt import torch from pina.callbacks import MetricTracker -from pina import LabelTensor +from .label_tensor import LabelTensor class Plotter: diff --git a/tests/test_label_tensor.py b/tests/test_label_tensor.py index 05dace5..5d5693d 100644 --- a/tests/test_label_tensor.py +++ b/tests/test_label_tensor.py @@ -1,119 +1,191 @@ import torch import pytest -from pina import LabelTensor +from pina.label_tensor import LabelTensor +#import pina data = torch.rand((20, 3)) -labels = ['a', 'b', 'c'] +labels_column = { + 1: { + "name": "space", + "dof": ['x', 'y', 'z'] + } +} +labels_row = { + 0: { + "name": "samples", + "dof": range(20) + } +} +labels_all = labels_column | labels_row - -def test_constructor(): +@pytest.mark.parametrize("labels", [labels_column, labels_row, labels_all]) +def test_constructor(labels): LabelTensor(data, labels) - def test_wrong_constructor(): with pytest.raises(ValueError): LabelTensor(data, ['a', 'b']) - -def test_labels(): +@pytest.mark.parametrize("labels", [labels_column, labels_all]) +@pytest.mark.parametrize("labels_te", ['z', ['z'], {'space': ['z']}]) +def test_extract_column(labels, labels_te): tensor = LabelTensor(data, labels) - assert isinstance(tensor, torch.Tensor) - assert tensor.labels == labels - with pytest.raises(ValueError): - tensor.labels = labels[:-1] + new = tensor.extract(labels_te) + assert new.ndim == tensor.ndim + assert new.shape[1] == 1 + assert new.shape[0] == 20 + assert torch.all(torch.isclose(data[:, 2].reshape(-1, 1), new)) - -def test_extract(): - label_to_extract = ['a', 'c'] +@pytest.mark.parametrize("labels", [labels_row, labels_all]) +@pytest.mark.parametrize("labels_te", [2, [2], {'samples': [2]}]) +def test_extract_row(labels, labels_te): tensor = LabelTensor(data, labels) - new = tensor.extract(label_to_extract) - assert new.labels == label_to_extract - assert new.shape[1] == len(label_to_extract) - assert torch.all(torch.isclose(data[:, 0::2], new)) + new = tensor.extract(labels_te) + assert new.ndim == tensor.ndim + assert new.shape[1] == 3 + assert new.shape[0] == 1 + assert torch.all(torch.isclose(data[2].reshape(1, -1), new)) - -def test_extract_onelabel(): - label_to_extract = ['a'] +@pytest.mark.parametrize("labels_te", [ + {'samples': [2], 'space': ['z']}, + {'space': 'z', 'samples': 2} +]) +def test_extract_2D(labels_te): + labels = labels_all tensor = LabelTensor(data, labels) - new = tensor.extract(label_to_extract) - assert new.ndim == 2 - assert new.labels == label_to_extract - assert new.shape[1] == len(label_to_extract) - assert torch.all(torch.isclose(data[:, 0].reshape(-1, 1), new)) + new = tensor.extract(labels_te) + assert new.ndim == tensor.ndim + assert new.shape[1] == 1 + assert new.shape[0] == 1 + assert torch.all(torch.isclose(data[2,2].reshape(1, 1), new)) - -def test_wrong_extract(): - label_to_extract = ['a', 'cc'] +def test_extract_3D(): + labels = labels_all + data = torch.rand((20, 3, 4)) + labels = { + 1: { + "name": "space", + "dof": ['x', 'y', 'z'] + }, + 2: { + "name": "time", + "dof": range(4) + }, + } + labels_te = { + 'space': ['x', 'z'], + 'time': range(1, 4) + } tensor = LabelTensor(data, labels) - with pytest.raises(ValueError): - tensor.extract(label_to_extract) + new = tensor.extract(labels_te) + assert new.ndim == tensor.ndim + assert new.shape[0] == 20 + assert new.shape[1] == 2 + assert new.shape[2] == 3 + assert torch.all(torch.isclose( + data[:, 0::2, 1:4].reshape(20, 2, 3), + new + )) + +# def test_labels(): +# tensor = LabelTensor(data, labels) +# assert isinstance(tensor, torch.Tensor) +# assert tensor.labels == labels +# with pytest.raises(ValueError): +# tensor.labels = labels[:-1] -def test_extract_order(): - label_to_extract = ['c', 'a'] - tensor = LabelTensor(data, labels) - new = tensor.extract(label_to_extract) - expected = torch.cat( - (data[:, 2].reshape(-1, 1), data[:, 0].reshape(-1, 1)), - dim=1) - assert new.labels == label_to_extract - assert new.shape[1] == len(label_to_extract) - assert torch.all(torch.isclose(expected, new)) +# def test_extract(): +# label_to_extract = ['a', 'c'] +# tensor = LabelTensor(data, labels) +# new = tensor.extract(label_to_extract) +# assert new.labels == label_to_extract +# assert new.shape[1] == len(label_to_extract) +# assert torch.all(torch.isclose(data[:, 0::2], new)) -def test_merge(): - tensor = LabelTensor(data, labels) - tensor_a = tensor.extract('a') - tensor_b = tensor.extract('b') - tensor_c = tensor.extract('c') - - tensor_bc = tensor_b.append(tensor_c) - assert torch.allclose(tensor_bc, tensor.extract(['b', 'c'])) +# def test_extract_onelabel(): +# label_to_extract = ['a'] +# tensor = LabelTensor(data, labels) +# new = tensor.extract(label_to_extract) +# assert new.ndim == 2 +# assert new.labels == label_to_extract +# assert new.shape[1] == len(label_to_extract) +# assert torch.all(torch.isclose(data[:, 0].reshape(-1, 1), new)) -def test_merge2(): - tensor = LabelTensor(data, labels) - tensor_b = tensor.extract('b') - tensor_c = tensor.extract('c') - - tensor_bc = tensor_b.append(tensor_c) - assert torch.allclose(tensor_bc, tensor.extract(['b', 'c'])) +# def test_wrong_extract(): +# label_to_extract = ['a', 'cc'] +# tensor = LabelTensor(data, labels) +# with pytest.raises(ValueError): +# tensor.extract(label_to_extract) -def test_getitem(): - tensor = LabelTensor(data, labels) - tensor_view = tensor['a'] - - assert tensor_view.labels == ['a'] - assert torch.allclose(tensor_view.flatten(), data[:, 0]) - - tensor_view = tensor['a', 'c'] - - assert tensor_view.labels == ['a', 'c'] - assert torch.allclose(tensor_view, data[:, 0::2]) - -def test_getitem2(): - tensor = LabelTensor(data, labels) - tensor_view = tensor[:5] - assert tensor_view.labels == labels - assert torch.allclose(tensor_view, data[:5]) - - idx = torch.randperm(tensor.shape[0]) - tensor_view = tensor[idx] - assert tensor_view.labels == labels +# def test_extract_order(): +# label_to_extract = ['c', 'a'] +# tensor = LabelTensor(data, labels) +# new = tensor.extract(label_to_extract) +# expected = torch.cat( +# (data[:, 2].reshape(-1, 1), data[:, 0].reshape(-1, 1)), +# dim=1) +# assert new.labels == label_to_extract +# assert new.shape[1] == len(label_to_extract) +# assert torch.all(torch.isclose(expected, new)) -def test_slice(): - tensor = LabelTensor(data, labels) - tensor_view = tensor[:5, :2] - assert tensor_view.labels == labels[:2] - assert torch.allclose(tensor_view, data[:5, :2]) +# def test_merge(): +# tensor = LabelTensor(data, labels) +# tensor_a = tensor.extract('a') +# tensor_b = tensor.extract('b') +# tensor_c = tensor.extract('c') - tensor_view2 = tensor[3] - assert tensor_view2.labels == labels - assert torch.allclose(tensor_view2, data[3]) +# tensor_bc = tensor_b.append(tensor_c) +# assert torch.allclose(tensor_bc, tensor.extract(['b', 'c'])) - tensor_view3 = tensor[:, 2] - assert tensor_view3.labels == labels[2] - assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1)) + +# def test_merge2(): +# tensor = LabelTensor(data, labels) +# tensor_b = tensor.extract('b') +# tensor_c = tensor.extract('c') + +# tensor_bc = tensor_b.append(tensor_c) +# assert torch.allclose(tensor_bc, tensor.extract(['b', 'c'])) + + +# def test_getitem(): +# tensor = LabelTensor(data, labels) +# tensor_view = tensor['a'] + +# assert tensor_view.labels == ['a'] +# assert torch.allclose(tensor_view.flatten(), data[:, 0]) + +# tensor_view = tensor['a', 'c'] + +# assert tensor_view.labels == ['a', 'c'] +# assert torch.allclose(tensor_view, data[:, 0::2]) + +# def test_getitem2(): +# tensor = LabelTensor(data, labels) +# tensor_view = tensor[:5] +# assert tensor_view.labels == labels +# assert torch.allclose(tensor_view, data[:5]) + +# idx = torch.randperm(tensor.shape[0]) +# tensor_view = tensor[idx] +# assert tensor_view.labels == labels + + +# def test_slice(): +# tensor = LabelTensor(data, labels) +# tensor_view = tensor[:5, :2] +# assert tensor_view.labels == labels[:2] +# assert torch.allclose(tensor_view, data[:5, :2]) + +# tensor_view2 = tensor[3] +# assert tensor_view2.labels == labels +# assert torch.allclose(tensor_view2, data[3]) + +# tensor_view3 = tensor[:, 2] +# assert tensor_view3.labels == labels[2] +# assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1))