Implement definition of LabelTensor from list, implement cat method (previously stack) and re-implement extract
committed by Nicola Demo
parent a779007b36
commit c53c3d5b84
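
In short: labels now live in a per-dimension dict ({dim_index: {'name': ..., 'dof': ...}}), a plain list of labels is attached to the last dimension, and `cat` replaces the old `stack`. A minimal usage sketch of the reworked API — the import path is assumed from the old doctest below, and shapes are taken from the tests:

import torch
from pina import LabelTensor  # import path assumed from the old doctest

# A list of labels is attached to the last dimension (one label per column).
t = LabelTensor(torch.rand(20, 3), ['a', 'b', 'c'])

# extract accepts a str, a list, or a dict mapping dimension name to dofs.
col_a = t.extract('a')           # ndim is preserved: shape (20, 1)
cols = t.extract(['a', 'c'])     # shape (20, 2)

# cat (previously stack) concatenates LabelTensors along a dimension.
t2 = LabelTensor(torch.rand(10, 3), ['a', 'b', 'c'])
both = LabelTensor.cat([t, t2], dim=0)   # shape (30, 3)
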
@@ -1,311 +1,8 @@
""" Module for LabelTensor """

from copy import deepcopy

import torch
from torch import Tensor


# class LabelTensor(torch.Tensor):
#     """Torch tensor with a label for any column."""

#     @staticmethod
#     def __new__(cls, x, labels, *args, **kwargs):
#         return super().__new__(cls, x, *args, **kwargs)

#     def __init__(self, x, labels):
#         """
#         Construct a `LabelTensor` by passing a tensor and a list of column
#         labels. Such labels uniquely identify the columns of the tensor,
#         allowing for an easier manipulation.

#         :param torch.Tensor x: The data tensor.
#         :param labels: The labels of the columns.
#         :type labels: str | list(str) | tuple(str)

#         :Example:
#             >>> from pina import LabelTensor
#             >>> tensor = LabelTensor(torch.rand((2000, 3)), ['a', 'b', 'c'])
#             >>> tensor
#             tensor([[6.7116e-02, 4.8892e-01, 8.9452e-01],
#                     [9.2392e-01, 8.2065e-01, 4.1986e-04],
#                     [8.9266e-01, 5.5446e-01, 6.3500e-01],
#                     ...,
#                     [5.8194e-01, 9.4268e-01, 4.1841e-01],
#                     [1.0246e-01, 9.5179e-01, 3.7043e-02],
#                     [9.6150e-01, 8.0656e-01, 8.3824e-01]])
#             >>> tensor.extract('a')
#             tensor([[0.0671],
#                     [0.9239],
#                     [0.8927],
#                     ...,
#                     [0.5819],
#                     [0.1025],
#                     [0.9615]])
#             >>> tensor['a']
#             tensor([[0.0671],
#                     [0.9239],
#                     [0.8927],
#                     ...,
#                     [0.5819],
#                     [0.1025],
#                     [0.9615]])
#             >>> tensor.extract(['a', 'b'])
#             tensor([[0.0671, 0.4889],
#                     [0.9239, 0.8207],
#                     [0.8927, 0.5545],
#                     ...,
#                     [0.5819, 0.9427],
#                     [0.1025, 0.9518],
#                     [0.9615, 0.8066]])
#             >>> tensor.extract(['b', 'a'])
#             tensor([[0.4889, 0.0671],
#                     [0.8207, 0.9239],
#                     [0.5545, 0.8927],
#                     ...,
#                     [0.9427, 0.5819],
#                     [0.9518, 0.1025],
#                     [0.8066, 0.9615]])
#         """
#         if x.ndim == 1:
#             x = x.reshape(-1, 1)

#         if isinstance(labels, str):
#             labels = [labels]

#         if len(labels) != x.shape[-1]:
#             raise ValueError(
#                 "the tensor has not the same number of columns of "
#                 "the passed labels."
#             )
#         self._labels = labels

#     def __deepcopy__(self, __):
#         """
#         Implements deepcopy for label tensor. By default it stores the
#         current labels and use the :meth:`~torch._tensor.Tensor.__deepcopy__`
#         method for creating a new :class:`pina.label_tensor.LabelTensor`.

#         :param __: Placeholder parameter.
#         :type __: None
#         :return: The deep copy of the :class:`pina.label_tensor.LabelTensor`.
#         :rtype: LabelTensor
#         """
#         labels = self.labels
#         copy_tensor = deepcopy(self.tensor)
#         return LabelTensor(copy_tensor, labels)

#     @property
#     def labels(self):
#         """Property decorator for labels

#         :return: labels of self
#         :rtype: list
#         """
#         return self._labels

#     @labels.setter
#     def labels(self, labels):
#         if len(labels) != self.shape[self.ndim - 1]:  # small check
#             raise ValueError(
#                 "The tensor has not the same number of columns of "
#                 "the passed labels."
#             )

#         self._labels = labels  # assign the label

#     @staticmethod
#     def vstack(label_tensors):
#         """
#         Stack tensors vertically. For more details, see
#         :meth:`torch.vstack`.

#         :param list(LabelTensor) label_tensors: the tensors to stack. They need
#             to have equal labels.
#         :return: the stacked tensor
#         :rtype: LabelTensor
#         """
#         if len(label_tensors) == 0:
#             return []

#         all_labels = [label for lt in label_tensors for label in lt.labels]
#         if set(all_labels) != set(label_tensors[0].labels):
#             raise RuntimeError("The tensors to stack have different labels")

#         labels = label_tensors[0].labels
#         tensors = [lt.extract(labels) for lt in label_tensors]
#         return LabelTensor(torch.vstack(tensors), labels)
#     def clone(self, *args, **kwargs):
#         """
#         Clone the LabelTensor. For more details, see
#         :meth:`torch.Tensor.clone`.

#         :return: A copy of the tensor.
#         :rtype: LabelTensor
#         """
#         # # used before merging
#         # try:
#         #     out = LabelTensor(super().clone(*args, **kwargs), self.labels)
#         # except:
#         #     out = super().clone(*args, **kwargs)
#         out = LabelTensor(super().clone(*args, **kwargs), self.labels)
#         return out

#     def to(self, *args, **kwargs):
#         """
#         Performs Tensor dtype and/or device conversion. For more details, see
#         :meth:`torch.Tensor.to`.
#         """
#         tmp = super().to(*args, **kwargs)
#         new = self.__class__.clone(self)
#         new.data = tmp.data
#         return new

#     def select(self, *args, **kwargs):
#         """
#         Performs Tensor selection. For more details, see :meth:`torch.Tensor.select`.
#         """
#         tmp = super().select(*args, **kwargs)
#         tmp._labels = self._labels
#         return tmp

#     def cuda(self, *args, **kwargs):
#         """
#         Send Tensor to cuda. For more details, see :meth:`torch.Tensor.cuda`.
#         """
#         tmp = super().cuda(*args, **kwargs)
#         new = self.__class__.clone(self)
#         new.data = tmp.data
#         return new

#     def cpu(self, *args, **kwargs):
#         """
#         Send Tensor to cpu. For more details, see :meth:`torch.Tensor.cpu`.
#         """
#         tmp = super().cpu(*args, **kwargs)
#         new = self.__class__.clone(self)
#         new.data = tmp.data
#         return new
#     def extract(self, label_to_extract):
#         """
#         Extract the subset of the original tensor by returning all the columns
#         corresponding to the passed ``label_to_extract``.

#         :param label_to_extract: The label(s) to extract.
#         :type label_to_extract: str | list(str) | tuple(str)
#         :raises TypeError: Labels are not ``str``.
#         :raises ValueError: Label to extract is not in the labels ``list``.
#         """
#         if isinstance(label_to_extract, str):
#             label_to_extract = [label_to_extract]
#         elif isinstance(label_to_extract, (tuple, list)):  # TODO
#             pass
#         else:
#             raise TypeError(
#                 "`label_to_extract` should be a str, or a str iterator"
#             )

#         indeces = []
#         for f in label_to_extract:
#             try:
#                 indeces.append(self.labels.index(f))
#             except ValueError:
#                 raise ValueError(f"`{f}` not in the labels list")

#         new_data = super(Tensor, self.T).__getitem__(indeces).T
#         new_labels = [self.labels[idx] for idx in indeces]

#         extracted_tensor = new_data.as_subclass(LabelTensor)
#         extracted_tensor.labels = new_labels

#         return extracted_tensor

#     def detach(self):
#         detached = super().detach()
#         if hasattr(self, "_labels"):
#             detached._labels = self._labels
#         return detached
#     def append(self, lt, mode="std"):
#         """
#         Return a copy of the merged tensors.

#         :param LabelTensor lt: The tensor to merge.
#         :param str mode: {'std', 'first', 'cross'}
#         :return: The merged tensors.
#         :rtype: LabelTensor
#         """
#         if set(self.labels).intersection(lt.labels):
#             raise RuntimeError("The tensors to merge have common labels")

#         new_labels = self.labels + lt.labels
#         if mode == "std":
#             new_tensor = torch.cat((self, lt), dim=1)
#         elif mode == "first":
#             raise NotImplementedError
#         elif mode == "cross":
#             tensor1 = self
#             tensor2 = lt
#             n1 = tensor1.shape[0]
#             n2 = tensor2.shape[0]

#             tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels)
#             tensor2 = LabelTensor(
#                 tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels
#             )
#             new_tensor = torch.cat((tensor1, tensor2), dim=1)

#         new_tensor = new_tensor.as_subclass(LabelTensor)
#         new_tensor.labels = new_labels
#         return new_tensor

#     def __getitem__(self, index):
#         """
#         Return a copy of the selected tensor.
#         """
#         if isinstance(index, str) or (
#             isinstance(index, (tuple, list))
#             and all(isinstance(a, str) for a in index)
#         ):
#             return self.extract(index)

#         selected_lt = super(Tensor, self).__getitem__(index)

#         try:
#             len_index = len(index)
#         except TypeError:
#             len_index = 1

#         if isinstance(index, int) or len_index == 1:
#             if selected_lt.ndim == 1:
#                 selected_lt = selected_lt.reshape(1, -1)
#             if hasattr(self, "labels"):
#                 selected_lt.labels = self.labels
#         elif len_index == 2:
#             if selected_lt.ndim == 1:
#                 selected_lt = selected_lt.reshape(-1, 1)
#             if hasattr(self, "labels"):
#                 if isinstance(index[1], list):
#                     selected_lt.labels = [self.labels[i] for i in index[1]]
#                 else:
#                     selected_lt.labels = self.labels[index[1]]
#         else:
#             selected_lt.labels = self.labels

#         return selected_lt

#     def __str__(self):
#         if hasattr(self, "labels"):
#             s = f"labels({str(self.labels)})\n"
#         else:
#             s = "no labels\n"
#         s += super().__str__()
#         return s
def issubset(a, b):
    """
    Check if a is a subset of b.
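
The diff view cuts the body off at the next hunk. Given how the helper is called below (`issubset(label_to_extract, v['dof'])`, where `dof` may be a list or a range), a plausible body — hypothetical, not the committed one — is:

def issubset(a, b):
    """Check if a is a subset of b."""
    # Hypothetical reconstruction: every element of a must occur in b.
    return set(a).issubset(set(b))
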
@@ -339,16 +36,14 @@ class LabelTensor(torch.Tensor):
    """
    from .utils import check_consistency
    check_consistency(labels, dict)

    self.labels = {
        idx_: {
            'dof': range(x.shape[idx_]),
            'name': idx_
        } for idx_ in range(x.ndim)
    }
    self.labels.update(labels)

    if isinstance(labels, dict):
        # check_consistency(labels, dict)
        self.update_labels(labels)
    elif isinstance(labels, list):
        self.init_labels_from_list(labels)
    elif isinstance(labels, str):
        labels = [labels]
        self.init_labels_from_list(labels)
    else:
        raise ValueError("labels must be list, dict or string.")
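
The labels dict maps a dimension index to its name and degrees of freedom ('dof'). A sketch of the three accepted forms — the values mirror the 3D test further down; names are illustrative:

import torch

x = torch.rand(20, 3, 4)

# dict form: explicit per-dimension labels (dimension index -> name and dof);
# unlisted dimensions default to dof=range(size) and name=index.
labels_as_dict = {
    1: {'name': 'space', 'dof': ['x', 'y', 'z']},
    2: {'name': 'time', 'dof': range(4)},
}

# list form: attached to the last dimension only (must match its size, here 4).
labels_as_list = ['t0', 't1', 't2', 't3']

# str form: shorthand for a one-element list; only valid when the
# last dimension has size 1.
labels_as_str = 'u'
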

def extract(self, label_to_extract):
    """
@@ -360,46 +55,47 @@ class LabelTensor(torch.Tensor):
    :raises TypeError: Labels are not ``str``.
    :raises ValueError: Label to extract is not in the labels ``list``.
    """
    if isinstance(label_to_extract, (int, str)):
    from copy import deepcopy
    if isinstance(label_to_extract, (str, int)):
        label_to_extract = [label_to_extract]
    if isinstance(label_to_extract, (tuple, list)):

        for k, v in self.labels.items():
            if issubset(label_to_extract, v['dof']):
                break

        label_to_extract = {v['name']: label_to_extract}

        last_dim_label = self.labels[self.tensor.ndim - 1]['dof']
        if set(label_to_extract).issubset(last_dim_label) is False:
            raise ValueError('Cannot extract a dof which is not in the original LabelTensor')
        idx_to_extract = [last_dim_label.index(i) for i in label_to_extract]
        new_tensor = deepcopy(self.tensor)
        new_tensor = new_tensor[..., idx_to_extract]
        new_labels = deepcopy(self.labels)
        last_dim_new_label = {self.tensor.ndim - 1: {
            'dof': label_to_extract,
            'name': self.labels[self.tensor.ndim - 1]['name']
        }}
        new_labels.update(last_dim_new_label)
    elif isinstance(label_to_extract, dict):
        new_labels = deepcopy(self.labels)
        new_tensor = deepcopy(self.tensor)
        for k, v in label_to_extract.items():
            if isinstance(v, (int, str)):
                label_to_extract[k] = [v]

            indeces = []
            for dim in range(self.ndim):
                boolean_idx = [True] * self.shape[dim]
                for dim_to_extract, dof_to_extract in label_to_extract.items():
                    if dim_to_extract == self.labels[dim]['name']:
                        boolean_idx = [False] * self.shape[dim]
                        for label in dof_to_extract:
                            idx_to_keep = self.labels[dim]['dof'].index(label)
                            boolean_idx[idx_to_keep] = True
                boolean_idx = torch.Tensor(boolean_idx).bool()
                indeces.append(boolean_idx)

            final_shapes = [sum(idx) for idx in indeces]
            grids = torch.meshgrid(*indeces)

            ii = grids[0]
            for grid in grids[1:]:
                ii = torch.logical_and(ii, grid)

            new_tensor = self.tensor[ii].reshape(*final_shapes)
            return LabelTensor(new_tensor, label_to_extract)

            idx_dim = None
            for kl, vl in self.labels.items():
                if vl['name'] == k:
                    idx_dim = kl
                    break
            dim_labels = self.labels[idx_dim]['dof']
            if isinstance(label_to_extract[k], (int, str)):
                label_to_extract[k] = [label_to_extract[k]]
            if not set(label_to_extract[k]).issubset(dim_labels):
                raise ValueError('Cannot extract a dof which is not in the original LabelTensor')
            idx_to_extract = [dim_labels.index(i) for i in label_to_extract[k]]
            indexer = [slice(None)] * idx_dim + [idx_to_extract] + [slice(None)] * (self.tensor.ndim - idx_dim - 1)
            new_tensor = new_tensor[indexer]
            dim_new_label = {idx_dim: {
                'dof': label_to_extract[k],
                'name': self.labels[idx_dim]['name']
            }}
            new_labels.update(dim_new_label)
    else:
        raise ValueError('labels_to_extract must be str or list or dict')
    return LabelTensor(new_tensor, new_labels)
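
A sketch of the new extract in action, using the 3D test case below as the ground truth (import path assumed):

import torch
from pina import LabelTensor  # import path assumed

data = torch.rand(20, 3, 4)
labels = {
    1: {'name': 'space', 'dof': ['x', 'y', 'z']},
    2: {'name': 'time', 'dof': range(4)},
}
tensor = LabelTensor(data, labels)

# Select dofs per named dimension; unnamed dimensions are kept whole.
new = tensor.extract({'space': ['x', 'z'], 'time': range(1, 4)})
assert new.ndim == tensor.ndim          # ndim is preserved
assert new.shape == (20, 2, 3)
assert torch.allclose(new, data[:, 0::2, 1:4])
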

def __str__(self):
    """
@@ -414,18 +110,40 @@ class LabelTensor(torch.Tensor):
    return s

@staticmethod
def stack(tensors):
def cat(tensors, dim=0):
    """
    Concatenate a list of tensors along dimension ``dim``. For example, given
    a tensor `a` of shape `(n, m, dof)` and a tensor `b` of shape
    `(n', m, dof)`, concatenation along `dim=0` gives a tensor of shape
    `(n+n', m, dof)`.
    """
    if len(tensors) == 0:
        return []

    if len(tensors) == 1:
        return tensors[0]

    raise NotImplementedError
    labels = [tensor.labels for tensor in tensors]

    n_dims = tensors[0].ndim
    new_labels_cat_dim = []
    for i in range(n_dims):
        name = tensors[0].labels[i]['name']
        if i != dim:
            dof = tensors[0].labels[i]['dof']
            for tensor in tensors:
                dof_to_check = tensor.labels[i]['dof']
                name_to_check = tensor.labels[i]['name']
                if dof != dof_to_check or name != name_to_check:
                    raise ValueError('dimensions must have the same dof and name')
        else:
            for tensor in tensors:
                new_labels_cat_dim += tensor.labels[i]['dof']
                name_to_check = tensor.labels[i]['name']
                if name != name_to_check:
                    raise ValueError('dimensions must have the same dof and name')
    new_tensor = torch.cat(tensors, dim=dim)
    labels = deepcopy(tensors[0].labels)  # copy, so the input tensor's labels are not mutated
    labels.pop(dim)
    new_labels_cat_dim = new_labels_cat_dim if len(set(new_labels_cat_dim)) == len(new_labels_cat_dim) \
        else range(new_tensor.shape[dim])
    labels[dim] = {'dof': new_labels_cat_dim,
                   'name': tensors[0].labels[dim]['name']}
    return LabelTensor(new_tensor, labels)
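
A usage sketch for the new cat (hypothetical tensors; recall that every non-concatenated dimension must agree in both name and dof):

import torch
from pina import LabelTensor  # import path assumed

a = LabelTensor(torch.rand(5, 3), ['u', 'v', 'w'])
b = LabelTensor(torch.rand(7, 3), ['u', 'v', 'w'])

# Concatenating along an unlabeled dimension: shapes add up, labels carry over.
ab = LabelTensor.cat([a, b], dim=0)   # shape (12, 3)

# Concatenating along the labelled dimension merges the dofs instead
# (they must stay unique, otherwise a plain range is used as fallback).
c = LabelTensor.cat(
    [a, LabelTensor(torch.rand(5, 2), ['p', 'q'])], dim=1)   # shape (5, 5)
# c.labels[1]['dof'] == ['u', 'v', 'w', 'p', 'q']
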

def requires_grad_(self, mode=True):
    lt = super().requires_grad_(mode)
@@ -454,10 +172,25 @@ class LabelTensor(torch.Tensor):
    :return: A copy of the tensor.
    :rtype: LabelTensor
    """
    # # used before merging
    # try:
    #     out = LabelTensor(super().clone(*args, **kwargs), self.labels)
    # except:
    #     out = super().clone(*args, **kwargs)

    out = LabelTensor(super().clone(*args, **kwargs), self.labels)
    return out

def update_labels(self, labels):
    self.labels = {
        idx_: {
            'dof': range(self.tensor.shape[idx_]),
            'name': idx_
        } for idx_ in range(self.tensor.ndim)
    }
    tensor_shape = self.tensor.shape
    for k, v in labels.items():
        if len(v['dof']) != len(set(v['dof'])):
            raise ValueError("dof must be unique")
        if len(v['dof']) != tensor_shape[k]:
            raise ValueError('Number of dof does not match with tensor dimension')
    self.labels.update(labels)

def init_labels_from_list(self, labels):
    last_dim_labels = {self.tensor.ndim - 1: {'dof': labels, 'name': self.tensor.ndim - 1}}
    self.update_labels(last_dim_labels)
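
So a plain list of labels is sugar for a dict entry on the last dimension, and update_labels validates the dofs. A short sketch (import path assumed):

import torch
from pina import LabelTensor  # import path assumed

t = LabelTensor(torch.rand(20, 3), ['a', 'b', 'c'])
# ...equivalent to the explicit dict form:
t2 = LabelTensor(torch.rand(20, 3), {1: {'dof': ['a', 'b', 'c'], 'name': 1}})

# Both validation errors come from update_labels:
# LabelTensor(torch.rand(20, 3), ['a', 'a', 'c'])  # ValueError: dof must be unique
# LabelTensor(torch.rand(20, 3), ['a', 'b'])       # ValueError: dof/dimension mismatch
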
@@ -38,7 +38,7 @@ def test_extract_column(labels, labels_te):
    assert torch.all(torch.isclose(data[:, 2].reshape(-1, 1), new))


@pytest.mark.parametrize("labels", [labels_row, labels_all])
@pytest.mark.parametrize("labels_te", [2, [2], {'samples': [2]}])
@pytest.mark.parametrize("labels_te", [{'samples': [2]}])
def test_extract_row(labels, labels_te):
    tensor = LabelTensor(data, labels)
    new = tensor.extract(labels_te)
@@ -62,7 +62,7 @@ def test_extract_2D(labels_te):

def test_extract_3D():
    labels = labels_all
    data = torch.rand((20, 3, 4))
    data = torch.rand(20, 3, 4)
    labels = {
        1: {
            "name": "space",
@@ -77,6 +77,7 @@ def test_extract_3D():
        'space': ['x', 'z'],
        'time': range(1, 4)
    }

    tensor = LabelTensor(data, labels)
    new = tensor.extract(labels_te)
    assert new.ndim == tensor.ndim
@@ -87,105 +88,3 @@ def test_extract_3D():
        data[:, 0::2, 1:4].reshape(20, 2, 3),
        new
    ))

# def test_labels():
#     tensor = LabelTensor(data, labels)
#     assert isinstance(tensor, torch.Tensor)
#     assert tensor.labels == labels
#     with pytest.raises(ValueError):
#         tensor.labels = labels[:-1]


# def test_extract():
#     label_to_extract = ['a', 'c']
#     tensor = LabelTensor(data, labels)
#     new = tensor.extract(label_to_extract)
#     assert new.labels == label_to_extract
#     assert new.shape[1] == len(label_to_extract)
#     assert torch.all(torch.isclose(data[:, 0::2], new))


# def test_extract_onelabel():
#     label_to_extract = ['a']
#     tensor = LabelTensor(data, labels)
#     new = tensor.extract(label_to_extract)
#     assert new.ndim == 2
#     assert new.labels == label_to_extract
#     assert new.shape[1] == len(label_to_extract)
#     assert torch.all(torch.isclose(data[:, 0].reshape(-1, 1), new))


# def test_wrong_extract():
#     label_to_extract = ['a', 'cc']
#     tensor = LabelTensor(data, labels)
#     with pytest.raises(ValueError):
#         tensor.extract(label_to_extract)


# def test_extract_order():
#     label_to_extract = ['c', 'a']
#     tensor = LabelTensor(data, labels)
#     new = tensor.extract(label_to_extract)
#     expected = torch.cat(
#         (data[:, 2].reshape(-1, 1), data[:, 0].reshape(-1, 1)),
#         dim=1)
#     assert new.labels == label_to_extract
#     assert new.shape[1] == len(label_to_extract)
#     assert torch.all(torch.isclose(expected, new))


# def test_merge():
#     tensor = LabelTensor(data, labels)
#     tensor_a = tensor.extract('a')
#     tensor_b = tensor.extract('b')
#     tensor_c = tensor.extract('c')

#     tensor_bc = tensor_b.append(tensor_c)
#     assert torch.allclose(tensor_bc, tensor.extract(['b', 'c']))


# def test_merge2():
#     tensor = LabelTensor(data, labels)
#     tensor_b = tensor.extract('b')
#     tensor_c = tensor.extract('c')

#     tensor_bc = tensor_b.append(tensor_c)
#     assert torch.allclose(tensor_bc, tensor.extract(['b', 'c']))


# def test_getitem():
#     tensor = LabelTensor(data, labels)
#     tensor_view = tensor['a']

#     assert tensor_view.labels == ['a']
#     assert torch.allclose(tensor_view.flatten(), data[:, 0])

#     tensor_view = tensor['a', 'c']

#     assert tensor_view.labels == ['a', 'c']
#     assert torch.allclose(tensor_view, data[:, 0::2])


# def test_getitem2():
#     tensor = LabelTensor(data, labels)
#     tensor_view = tensor[:5]
#     assert tensor_view.labels == labels
#     assert torch.allclose(tensor_view, data[:5])

#     idx = torch.randperm(tensor.shape[0])
#     tensor_view = tensor[idx]
#     assert tensor_view.labels == labels


# def test_slice():
#     tensor = LabelTensor(data, labels)
#     tensor_view = tensor[:5, :2]
#     assert tensor_view.labels == labels[:2]
#     assert torch.allclose(tensor_view, data[:5, :2])

#     tensor_view2 = tensor[3]
#     assert tensor_view2.labels == labels
#     assert torch.allclose(tensor_view2, data[3])

#     tensor_view3 = tensor[:, 2]
#     assert tensor_view3.labels == labels[2]
#     assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1))