in progress

This commit is contained in:
Nicola Demo
2024-06-18 11:20:58 +02:00
parent e4c69515dd
commit 2b71e0148d
7 changed files with 534 additions and 364 deletions

View File

@@ -9,7 +9,7 @@ __all__ = [
]
from .meta import *
from .label_tensor import LabelTensor
#from .label_tensor import LabelTensor
from .solvers.solver import SolverInterface
from .trainer import Trainer
from .plotter import Plotter

View File

@@ -1,6 +1,6 @@
from torch.utils.data import Dataset
import torch
from pina import LabelTensor
from .label_tensor import LabelTensor
class SamplePointDataset(Dataset):

View File

@@ -1,7 +1,7 @@
import torch
from .location import Location
from pina.geometry import CartesianDomain
from pina import LabelTensor
from ..label_tensor import LabelTensor
from ..utils import check_consistency

View File

@@ -5,6 +5,319 @@ import torch
from torch import Tensor
# class LabelTensor(torch.Tensor):
# """Torch tensor with a label for any column."""
# @staticmethod
# def __new__(cls, x, labels, *args, **kwargs):
# return super().__new__(cls, x, *args, **kwargs)
# def __init__(self, x, labels):
# """
# Construct a `LabelTensor` by passing a tensor and a list of column
# labels. Such labels uniquely identify the columns of the tensor,
# allowing for an easier manipulation.
# :param torch.Tensor x: The data tensor.
# :param labels: The labels of the columns.
# :type labels: str | list(str) | tuple(str)
# :Example:
# >>> from pina import LabelTensor
# >>> tensor = LabelTensor(torch.rand((2000, 3)), ['a', 'b', 'c'])
# >>> tensor
# tensor([[6.7116e-02, 4.8892e-01, 8.9452e-01],
# [9.2392e-01, 8.2065e-01, 4.1986e-04],
# [8.9266e-01, 5.5446e-01, 6.3500e-01],
# ...,
# [5.8194e-01, 9.4268e-01, 4.1841e-01],
# [1.0246e-01, 9.5179e-01, 3.7043e-02],
# [9.6150e-01, 8.0656e-01, 8.3824e-01]])
# >>> tensor.extract('a')
# tensor([[0.0671],
# [0.9239],
# [0.8927],
# ...,
# [0.5819],
# [0.1025],
# [0.9615]])
# >>> tensor['a']
# tensor([[0.0671],
# [0.9239],
# [0.8927],
# ...,
# [0.5819],
# [0.1025],
# [0.9615]])
# >>> tensor.extract(['a', 'b'])
# tensor([[0.0671, 0.4889],
# [0.9239, 0.8207],
# [0.8927, 0.5545],
# ...,
# [0.5819, 0.9427],
# [0.1025, 0.9518],
# [0.9615, 0.8066]])
# >>> tensor.extract(['b', 'a'])
# tensor([[0.4889, 0.0671],
# [0.8207, 0.9239],
# [0.5545, 0.8927],
# ...,
# [0.9427, 0.5819],
# [0.9518, 0.1025],
# [0.8066, 0.9615]])
# """
# if x.ndim == 1:
# x = x.reshape(-1, 1)
# if isinstance(labels, str):
# labels = [labels]
# if len(labels) != x.shape[-1]:
# raise ValueError(
# "the tensor has not the same number of columns of "
# "the passed labels."
# )
# self._labels = labels
# def __deepcopy__(self, __):
# """
# Implements deepcopy for label tensor. By default it stores the
# current labels and use the :meth:`~torch._tensor.Tensor.__deepcopy__`
# method for creating a new :class:`pina.label_tensor.LabelTensor`.
# :param __: Placeholder parameter.
# :type __: None
# :return: The deep copy of the :class:`pina.label_tensor.LabelTensor`.
# :rtype: LabelTensor
# """
# labels = self.labels
# copy_tensor = deepcopy(self.tensor)
# return LabelTensor(copy_tensor, labels)
# @property
# def labels(self):
# """Property decorator for labels
# :return: labels of self
# :rtype: list
# """
# return self._labels
# @labels.setter
# def labels(self, labels):
# if len(labels) != self.shape[self.ndim - 1]: # small check
# raise ValueError(
# "The tensor has not the same number of columns of "
# "the passed labels."
# )
# self._labels = labels # assign the label
# @staticmethod
# def vstack(label_tensors):
# """
# Stack tensors vertically. For more details, see
# :meth:`torch.vstack`.
# :param list(LabelTensor) label_tensors: the tensors to stack. They need
# to have equal labels.
# :return: the stacked tensor
# :rtype: LabelTensor
# """
# if len(label_tensors) == 0:
# return []
# all_labels = [label for lt in label_tensors for label in lt.labels]
# if set(all_labels) != set(label_tensors[0].labels):
# raise RuntimeError("The tensors to stack have different labels")
# labels = label_tensors[0].labels
# tensors = [lt.extract(labels) for lt in label_tensors]
# return LabelTensor(torch.vstack(tensors), labels)
# def clone(self, *args, **kwargs):
# """
# Clone the LabelTensor. For more details, see
# :meth:`torch.Tensor.clone`.
# :return: A copy of the tensor.
# :rtype: LabelTensor
# """
# # # used before merging
# # try:
# # out = LabelTensor(super().clone(*args, **kwargs), self.labels)
# # except:
# # out = super().clone(*args, **kwargs)
# out = LabelTensor(super().clone(*args, **kwargs), self.labels)
# return out
# def to(self, *args, **kwargs):
# """
# Performs Tensor dtype and/or device conversion. For more details, see
# :meth:`torch.Tensor.to`.
# """
# tmp = super().to(*args, **kwargs)
# new = self.__class__.clone(self)
# new.data = tmp.data
# return new
# def select(self, *args, **kwargs):
# """
# Performs Tensor selection. For more details, see :meth:`torch.Tensor.select`.
# """
# tmp = super().select(*args, **kwargs)
# tmp._labels = self._labels
# return tmp
# def cuda(self, *args, **kwargs):
# """
# Send Tensor to cuda. For more details, see :meth:`torch.Tensor.cuda`.
# """
# tmp = super().cuda(*args, **kwargs)
# new = self.__class__.clone(self)
# new.data = tmp.data
# return new
# def cpu(self, *args, **kwargs):
# """
# Send Tensor to cpu. For more details, see :meth:`torch.Tensor.cpu`.
# """
# tmp = super().cpu(*args, **kwargs)
# new = self.__class__.clone(self)
# new.data = tmp.data
# return new
# def extract(self, label_to_extract):
# """
# Extract the subset of the original tensor by returning all the columns
# corresponding to the passed ``label_to_extract``.
# :param label_to_extract: The label(s) to extract.
# :type label_to_extract: str | list(str) | tuple(str)
# :raises TypeError: Labels are not ``str``.
# :raises ValueError: Label to extract is not in the labels ``list``.
# """
# if isinstance(label_to_extract, str):
# label_to_extract = [label_to_extract]
# elif isinstance(label_to_extract, (tuple, list)): # TODO
# pass
# else:
# raise TypeError(
# "`label_to_extract` should be a str, or a str iterator"
# )
# indeces = []
# for f in label_to_extract:
# try:
# indeces.append(self.labels.index(f))
# except ValueError:
# raise ValueError(f"`{f}` not in the labels list")
# new_data = super(Tensor, self.T).__getitem__(indeces).T
# new_labels = [self.labels[idx] for idx in indeces]
# extracted_tensor = new_data.as_subclass(LabelTensor)
# extracted_tensor.labels = new_labels
# return extracted_tensor
# def detach(self):
# detached = super().detach()
# if hasattr(self, "_labels"):
# detached._labels = self._labels
# return detached
# def requires_grad_(self, mode=True):
# lt = super().requires_grad_(mode)
# lt.labels = self.labels
# return lt
# def append(self, lt, mode="std"):
# """
# Return a copy of the merged tensors.
# :param LabelTensor lt: The tensor to merge.
# :param str mode: {'std', 'first', 'cross'}
# :return: The merged tensors.
# :rtype: LabelTensor
# """
# if set(self.labels).intersection(lt.labels):
# raise RuntimeError("The tensors to merge have common labels")
# new_labels = self.labels + lt.labels
# if mode == "std":
# new_tensor = torch.cat((self, lt), dim=1)
# elif mode == "first":
# raise NotImplementedError
# elif mode == "cross":
# tensor1 = self
# tensor2 = lt
# n1 = tensor1.shape[0]
# n2 = tensor2.shape[0]
# tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels)
# tensor2 = LabelTensor(
# tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels
# )
# new_tensor = torch.cat((tensor1, tensor2), dim=1)
# new_tensor = new_tensor.as_subclass(LabelTensor)
# new_tensor.labels = new_labels
# return new_tensor
# def __getitem__(self, index):
# """
# Return a copy of the selected tensor.
# """
# if isinstance(index, str) or (
# isinstance(index, (tuple, list))
# and all(isinstance(a, str) for a in index)
# ):
# return self.extract(index)
# selected_lt = super(Tensor, self).__getitem__(index)
# try:
# len_index = len(index)
# except TypeError:
# len_index = 1
# if isinstance(index, int) or len_index == 1:
# if selected_lt.ndim == 1:
# selected_lt = selected_lt.reshape(1, -1)
# if hasattr(self, "labels"):
# selected_lt.labels = self.labels
# elif len_index == 2:
# if selected_lt.ndim == 1:
# selected_lt = selected_lt.reshape(-1, 1)
# if hasattr(self, "labels"):
# if isinstance(index[1], list):
# selected_lt.labels = [self.labels[i] for i in index[1]]
# else:
# selected_lt.labels = self.labels[index[1]]
# else:
# selected_lt.labels = self.labels
# return selected_lt
# def __str__(self):
# if hasattr(self, "labels"):
# s = f"labels({str(self.labels)})\n"
# else:
# s = "no labels\n"
# s += super().__str__()
# return s
def issubset(a, b):
"""
Check if a is a subset of b.
"""
return set(a).issubset(set(b))
class LabelTensor(torch.Tensor):
"""Torch tensor with a label for any column."""
@@ -12,180 +325,35 @@ class LabelTensor(torch.Tensor):
def __new__(cls, x, labels, *args, **kwargs):
return super().__new__(cls, x, *args, **kwargs)
@property
def tensor(self):
return self.as_subclass(Tensor)
def __len__(self) -> int:
return super().__len__()
def __init__(self, x, labels):
"""
Construct a `LabelTensor` by passing a tensor and a list of column
labels. Such labels uniquely identify the columns of the tensor,
allowing for an easier manipulation.
:param torch.Tensor x: The data tensor.
:param labels: The labels of the columns.
:type labels: str | list(str) | tuple(str)
Construct a `LabelTensor` by passing a dict of the labels
:Example:
>>> from pina import LabelTensor
>>> tensor = LabelTensor(torch.rand((2000, 3)), ['a', 'b', 'c'])
>>> tensor
tensor([[6.7116e-02, 4.8892e-01, 8.9452e-01],
[9.2392e-01, 8.2065e-01, 4.1986e-04],
[8.9266e-01, 5.5446e-01, 6.3500e-01],
...,
[5.8194e-01, 9.4268e-01, 4.1841e-01],
[1.0246e-01, 9.5179e-01, 3.7043e-02],
[9.6150e-01, 8.0656e-01, 8.3824e-01]])
>>> tensor.extract('a')
tensor([[0.0671],
[0.9239],
[0.8927],
...,
[0.5819],
[0.1025],
[0.9615]])
>>> tensor['a']
tensor([[0.0671],
[0.9239],
[0.8927],
...,
[0.5819],
[0.1025],
[0.9615]])
>>> tensor.extract(['a', 'b'])
tensor([[0.0671, 0.4889],
[0.9239, 0.8207],
[0.8927, 0.5545],
...,
[0.5819, 0.9427],
[0.1025, 0.9518],
[0.9615, 0.8066]])
>>> tensor.extract(['b', 'a'])
tensor([[0.4889, 0.0671],
[0.8207, 0.9239],
[0.5545, 0.8927],
...,
[0.9427, 0.5819],
[0.9518, 0.1025],
[0.8066, 0.9615]])
>>> tensor = LabelTensor(
>>> torch.rand((2000, 3)),
{1: {"name": "space"['a', 'b', 'c'])
"""
if x.ndim == 1:
x = x.reshape(-1, 1)
from .utils import check_consistency
check_consistency(labels, dict)
if isinstance(labels, str):
labels = [labels]
self.labels = {
idx_: {
'dof': range(x.shape[idx_]),
'name': idx_
} for idx_ in range(x.ndim)
}
self.labels.update(labels)
if len(labels) != x.shape[-1]:
raise ValueError(
"the tensor has not the same number of columns of "
"the passed labels."
)
self._labels = labels
def __deepcopy__(self, __):
"""
Implements deepcopy for label tensor. By default it stores the
current labels and use the :meth:`~torch._tensor.Tensor.__deepcopy__`
method for creating a new :class:`pina.label_tensor.LabelTensor`.
:param __: Placeholder parameter.
:type __: None
:return: The deep copy of the :class:`pina.label_tensor.LabelTensor`.
:rtype: LabelTensor
"""
labels = self.labels
copy_tensor = deepcopy(self.tensor)
return LabelTensor(copy_tensor, labels)
@property
def labels(self):
"""Property decorator for labels
:return: labels of self
:rtype: list
"""
return self._labels
@labels.setter
def labels(self, labels):
if len(labels) != self.shape[self.ndim - 1]: # small check
raise ValueError(
"The tensor has not the same number of columns of "
"the passed labels."
)
self._labels = labels # assign the label
@staticmethod
def vstack(label_tensors):
"""
Stack tensors vertically. For more details, see
:meth:`torch.vstack`.
:param list(LabelTensor) label_tensors: the tensors to stack. They need
to have equal labels.
:return: the stacked tensor
:rtype: LabelTensor
"""
if len(label_tensors) == 0:
return []
all_labels = [label for lt in label_tensors for label in lt.labels]
if set(all_labels) != set(label_tensors[0].labels):
raise RuntimeError("The tensors to stack have different labels")
labels = label_tensors[0].labels
tensors = [lt.extract(labels) for lt in label_tensors]
return LabelTensor(torch.vstack(tensors), labels)
def clone(self, *args, **kwargs):
"""
Clone the LabelTensor. For more details, see
:meth:`torch.Tensor.clone`.
:return: A copy of the tensor.
:rtype: LabelTensor
"""
# # used before merging
# try:
# out = LabelTensor(super().clone(*args, **kwargs), self.labels)
# except:
# out = super().clone(*args, **kwargs)
out = LabelTensor(super().clone(*args, **kwargs), self.labels)
return out
def to(self, *args, **kwargs):
"""
Performs Tensor dtype and/or device conversion. For more details, see
:meth:`torch.Tensor.to`.
"""
tmp = super().to(*args, **kwargs)
new = self.__class__.clone(self)
new.data = tmp.data
return new
def select(self, *args, **kwargs):
"""
Performs Tensor selection. For more details, see :meth:`torch.Tensor.select`.
"""
tmp = super().select(*args, **kwargs)
tmp._labels = self._labels
return tmp
def cuda(self, *args, **kwargs):
"""
Send Tensor to cuda. For more details, see :meth:`torch.Tensor.cuda`.
"""
tmp = super().cuda(*args, **kwargs)
new = self.__class__.clone(self)
new.data = tmp.data
return new
def cpu(self, *args, **kwargs):
"""
Send Tensor to cpu. For more details, see :meth:`torch.Tensor.cpu`.
"""
tmp = super().cpu(*args, **kwargs)
new = self.__class__.clone(self)
new.data = tmp.data
return new
def extract(self, label_to_extract):
"""
@@ -197,122 +365,52 @@ class LabelTensor(torch.Tensor):
:raises TypeError: Labels are not ``str``.
:raises ValueError: Label to extract is not in the labels ``list``.
"""
if isinstance(label_to_extract, str):
if isinstance(label_to_extract, (int, str)):
label_to_extract = [label_to_extract]
elif isinstance(label_to_extract, (tuple, list)): # TODO
pass
else:
raise TypeError(
"`label_to_extract` should be a str, or a str iterator"
)
if isinstance(label_to_extract, (tuple, list)):
for k, v in self.labels.items():
if issubset(label_to_extract, v['dof']):
break
label_to_extract = {v['name']: label_to_extract}
for k, v in label_to_extract.items():
if isinstance(v, (int, str)):
label_to_extract[k] = [v]
indeces = []
for f in label_to_extract:
try:
indeces.append(self.labels.index(f))
except ValueError:
raise ValueError(f"`{f}` not in the labels list")
for dim in range(self.ndim):
new_data = super(Tensor, self.T).__getitem__(indeces).T
new_labels = [self.labels[idx] for idx in indeces]
boolean_idx = [True] * self.shape[dim]
extracted_tensor = new_data.as_subclass(LabelTensor)
extracted_tensor.labels = new_labels
for dim_to_extract, dof_to_extract in label_to_extract.items():
if dim_to_extract == self.labels[dim]['name']:
boolean_idx = [False] * self.shape[dim]
for label in dof_to_extract:
idx_to_keep = self.labels[dim]['dof'].index(label)
boolean_idx[idx_to_keep] = True
return extracted_tensor
boolean_idx = torch.Tensor(boolean_idx).bool()
def detach(self):
detached = super().detach()
if hasattr(self, "_labels"):
detached._labels = self._labels
return detached
indeces.append(boolean_idx)
def requires_grad_(self, mode=True):
lt = super().requires_grad_(mode)
lt.labels = self.labels
return lt
final_shapes = [sum(idx) for idx in indeces]
grids = torch.meshgrid(*indeces)
def append(self, lt, mode="std"):
"""
Return a copy of the merged tensors.
ii = grids[0]
for grid in grids[1:]:
ii = torch.logical_and(ii, grid)
:param LabelTensor lt: The tensor to merge.
:param str mode: {'std', 'first', 'cross'}
:return: The merged tensors.
:rtype: LabelTensor
"""
if set(self.labels).intersection(lt.labels):
raise RuntimeError("The tensors to merge have common labels")
new_tensor = self.tensor[ii].reshape(*final_shapes)
new_labels = self.labels + lt.labels
if mode == "std":
new_tensor = torch.cat((self, lt), dim=1)
elif mode == "first":
raise NotImplementedError
elif mode == "cross":
tensor1 = self
tensor2 = lt
n1 = tensor1.shape[0]
n2 = tensor2.shape[0]
return LabelTensor(new_tensor, label_to_extract)
tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels)
tensor2 = LabelTensor(
tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels
)
new_tensor = torch.cat((tensor1, tensor2), dim=1)
new_tensor = new_tensor.as_subclass(LabelTensor)
new_tensor.labels = new_labels
return new_tensor
def __getitem__(self, index):
"""
Return a copy of the selected tensor.
"""
if isinstance(index, str) or (
isinstance(index, (tuple, list))
and all(isinstance(a, str) for a in index)
):
return self.extract(index)
selected_lt = super(Tensor, self).__getitem__(index)
try:
len_index = len(index)
except TypeError:
len_index = 1
if isinstance(index, int) or len_index == 1:
if selected_lt.ndim == 1:
selected_lt = selected_lt.reshape(1, -1)
if hasattr(self, "labels"):
selected_lt.labels = self.labels
elif len_index == 2:
if selected_lt.ndim == 1:
selected_lt = selected_lt.reshape(-1, 1)
if hasattr(self, "labels"):
if isinstance(index[1], list):
selected_lt.labels = [self.labels[i] for i in index[1]]
else:
selected_lt.labels = self.labels[index[1]]
else:
selected_lt.labels = self.labels
return selected_lt
@property
def tensor(self):
return self.as_subclass(Tensor)
def __len__(self) -> int:
return super().__len__()
def __str__(self):
if hasattr(self, "labels"):
s = f"labels({str(self.labels)})\n"
else:
s = "no labels\n"
s = ''
for key, value in self.labels.items():
s += f"{key}: {value}\n"
s += '\n'
s += super().__str__()
return s

View File

@@ -4,7 +4,7 @@ Fourier Neural Operator Module.
import torch
import torch.nn as nn
from pina import LabelTensor
from ..label_tensor import LabelTensor
import warnings
from ..utils import check_consistency
from .layers.fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D

View File

@@ -3,7 +3,7 @@
import matplotlib.pyplot as plt
import torch
from pina.callbacks import MetricTracker
from pina import LabelTensor
from .label_tensor import LabelTensor
class Plotter:

View File

@@ -1,119 +1,191 @@
import torch
import pytest
from pina import LabelTensor
from pina.label_tensor import LabelTensor
#import pina
data = torch.rand((20, 3))
labels = ['a', 'b', 'c']
labels_column = {
1: {
"name": "space",
"dof": ['x', 'y', 'z']
}
}
labels_row = {
0: {
"name": "samples",
"dof": range(20)
}
}
labels_all = labels_column | labels_row
def test_constructor():
@pytest.mark.parametrize("labels", [labels_column, labels_row, labels_all])
def test_constructor(labels):
LabelTensor(data, labels)
def test_wrong_constructor():
with pytest.raises(ValueError):
LabelTensor(data, ['a', 'b'])
def test_labels():
@pytest.mark.parametrize("labels", [labels_column, labels_all])
@pytest.mark.parametrize("labels_te", ['z', ['z'], {'space': ['z']}])
def test_extract_column(labels, labels_te):
tensor = LabelTensor(data, labels)
assert isinstance(tensor, torch.Tensor)
assert tensor.labels == labels
with pytest.raises(ValueError):
tensor.labels = labels[:-1]
new = tensor.extract(labels_te)
assert new.ndim == tensor.ndim
assert new.shape[1] == 1
assert new.shape[0] == 20
assert torch.all(torch.isclose(data[:, 2].reshape(-1, 1), new))
def test_extract():
label_to_extract = ['a', 'c']
@pytest.mark.parametrize("labels", [labels_row, labels_all])
@pytest.mark.parametrize("labels_te", [2, [2], {'samples': [2]}])
def test_extract_row(labels, labels_te):
tensor = LabelTensor(data, labels)
new = tensor.extract(label_to_extract)
assert new.labels == label_to_extract
assert new.shape[1] == len(label_to_extract)
assert torch.all(torch.isclose(data[:, 0::2], new))
new = tensor.extract(labels_te)
assert new.ndim == tensor.ndim
assert new.shape[1] == 3
assert new.shape[0] == 1
assert torch.all(torch.isclose(data[2].reshape(1, -1), new))
def test_extract_onelabel():
label_to_extract = ['a']
@pytest.mark.parametrize("labels_te", [
{'samples': [2], 'space': ['z']},
{'space': 'z', 'samples': 2}
])
def test_extract_2D(labels_te):
labels = labels_all
tensor = LabelTensor(data, labels)
new = tensor.extract(label_to_extract)
assert new.ndim == 2
assert new.labels == label_to_extract
assert new.shape[1] == len(label_to_extract)
assert torch.all(torch.isclose(data[:, 0].reshape(-1, 1), new))
new = tensor.extract(labels_te)
assert new.ndim == tensor.ndim
assert new.shape[1] == 1
assert new.shape[0] == 1
assert torch.all(torch.isclose(data[2,2].reshape(1, 1), new))
def test_wrong_extract():
label_to_extract = ['a', 'cc']
def test_extract_3D():
labels = labels_all
data = torch.rand((20, 3, 4))
labels = {
1: {
"name": "space",
"dof": ['x', 'y', 'z']
},
2: {
"name": "time",
"dof": range(4)
},
}
labels_te = {
'space': ['x', 'z'],
'time': range(1, 4)
}
tensor = LabelTensor(data, labels)
with pytest.raises(ValueError):
tensor.extract(label_to_extract)
new = tensor.extract(labels_te)
assert new.ndim == tensor.ndim
assert new.shape[0] == 20
assert new.shape[1] == 2
assert new.shape[2] == 3
assert torch.all(torch.isclose(
data[:, 0::2, 1:4].reshape(20, 2, 3),
new
))
# def test_labels():
# tensor = LabelTensor(data, labels)
# assert isinstance(tensor, torch.Tensor)
# assert tensor.labels == labels
# with pytest.raises(ValueError):
# tensor.labels = labels[:-1]
def test_extract_order():
label_to_extract = ['c', 'a']
tensor = LabelTensor(data, labels)
new = tensor.extract(label_to_extract)
expected = torch.cat(
(data[:, 2].reshape(-1, 1), data[:, 0].reshape(-1, 1)),
dim=1)
assert new.labels == label_to_extract
assert new.shape[1] == len(label_to_extract)
assert torch.all(torch.isclose(expected, new))
# def test_extract():
# label_to_extract = ['a', 'c']
# tensor = LabelTensor(data, labels)
# new = tensor.extract(label_to_extract)
# assert new.labels == label_to_extract
# assert new.shape[1] == len(label_to_extract)
# assert torch.all(torch.isclose(data[:, 0::2], new))
def test_merge():
tensor = LabelTensor(data, labels)
tensor_a = tensor.extract('a')
tensor_b = tensor.extract('b')
tensor_c = tensor.extract('c')
tensor_bc = tensor_b.append(tensor_c)
assert torch.allclose(tensor_bc, tensor.extract(['b', 'c']))
# def test_extract_onelabel():
# label_to_extract = ['a']
# tensor = LabelTensor(data, labels)
# new = tensor.extract(label_to_extract)
# assert new.ndim == 2
# assert new.labels == label_to_extract
# assert new.shape[1] == len(label_to_extract)
# assert torch.all(torch.isclose(data[:, 0].reshape(-1, 1), new))
def test_merge2():
tensor = LabelTensor(data, labels)
tensor_b = tensor.extract('b')
tensor_c = tensor.extract('c')
tensor_bc = tensor_b.append(tensor_c)
assert torch.allclose(tensor_bc, tensor.extract(['b', 'c']))
# def test_wrong_extract():
# label_to_extract = ['a', 'cc']
# tensor = LabelTensor(data, labels)
# with pytest.raises(ValueError):
# tensor.extract(label_to_extract)
def test_getitem():
tensor = LabelTensor(data, labels)
tensor_view = tensor['a']
assert tensor_view.labels == ['a']
assert torch.allclose(tensor_view.flatten(), data[:, 0])
tensor_view = tensor['a', 'c']
assert tensor_view.labels == ['a', 'c']
assert torch.allclose(tensor_view, data[:, 0::2])
def test_getitem2():
tensor = LabelTensor(data, labels)
tensor_view = tensor[:5]
assert tensor_view.labels == labels
assert torch.allclose(tensor_view, data[:5])
idx = torch.randperm(tensor.shape[0])
tensor_view = tensor[idx]
assert tensor_view.labels == labels
# def test_extract_order():
# label_to_extract = ['c', 'a']
# tensor = LabelTensor(data, labels)
# new = tensor.extract(label_to_extract)
# expected = torch.cat(
# (data[:, 2].reshape(-1, 1), data[:, 0].reshape(-1, 1)),
# dim=1)
# assert new.labels == label_to_extract
# assert new.shape[1] == len(label_to_extract)
# assert torch.all(torch.isclose(expected, new))
def test_slice():
tensor = LabelTensor(data, labels)
tensor_view = tensor[:5, :2]
assert tensor_view.labels == labels[:2]
assert torch.allclose(tensor_view, data[:5, :2])
# def test_merge():
# tensor = LabelTensor(data, labels)
# tensor_a = tensor.extract('a')
# tensor_b = tensor.extract('b')
# tensor_c = tensor.extract('c')
tensor_view2 = tensor[3]
assert tensor_view2.labels == labels
assert torch.allclose(tensor_view2, data[3])
# tensor_bc = tensor_b.append(tensor_c)
# assert torch.allclose(tensor_bc, tensor.extract(['b', 'c']))
tensor_view3 = tensor[:, 2]
assert tensor_view3.labels == labels[2]
assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1))
# def test_merge2():
# tensor = LabelTensor(data, labels)
# tensor_b = tensor.extract('b')
# tensor_c = tensor.extract('c')
# tensor_bc = tensor_b.append(tensor_c)
# assert torch.allclose(tensor_bc, tensor.extract(['b', 'c']))
# def test_getitem():
# tensor = LabelTensor(data, labels)
# tensor_view = tensor['a']
# assert tensor_view.labels == ['a']
# assert torch.allclose(tensor_view.flatten(), data[:, 0])
# tensor_view = tensor['a', 'c']
# assert tensor_view.labels == ['a', 'c']
# assert torch.allclose(tensor_view, data[:, 0::2])
# def test_getitem2():
# tensor = LabelTensor(data, labels)
# tensor_view = tensor[:5]
# assert tensor_view.labels == labels
# assert torch.allclose(tensor_view, data[:5])
# idx = torch.randperm(tensor.shape[0])
# tensor_view = tensor[idx]
# assert tensor_view.labels == labels
# def test_slice():
# tensor = LabelTensor(data, labels)
# tensor_view = tensor[:5, :2]
# assert tensor_view.labels == labels[:2]
# assert torch.allclose(tensor_view, data[:5, :2])
# tensor_view2 = tensor[3]
# assert tensor_view2.labels == labels
# assert torch.allclose(tensor_view2, data[3])
# tensor_view3 = tensor[:, 2]
# assert tensor_view3.labels == labels[2]
# assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1))