adjust LabelTensor (inheritance)
This commit is contained in:
7
docs/source/_rst/code.rst
Normal file
7
docs/source/_rst/code.rst
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
Code Documentation
|
||||||
|
==================
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:maxdepth: 3
|
||||||
|
|
||||||
|
LabelTensor <label_tensor.rst>
|
||||||
12
docs/source/_rst/label_tensor.rst
Normal file
12
docs/source/_rst/label_tensor.rst
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
LabelTensor
|
||||||
|
===========
|
||||||
|
.. currentmodule:: pina.label_tensor
|
||||||
|
|
||||||
|
.. automodule:: pina.label_tensor
|
||||||
|
|
||||||
|
.. autoclass:: LabelTensor
|
||||||
|
:members:
|
||||||
|
:private-members:
|
||||||
|
:undoc-members:
|
||||||
|
:show-inheritance:
|
||||||
|
:noindex:
|
||||||
@@ -47,12 +47,14 @@ extensions = [
|
|||||||
'sphinx.ext.ifconfig',
|
'sphinx.ext.ifconfig',
|
||||||
'sphinx.ext.mathjax',
|
'sphinx.ext.mathjax',
|
||||||
]
|
]
|
||||||
|
autosummary_generate = True
|
||||||
|
|
||||||
intersphinx_mapping = {
|
intersphinx_mapping = {
|
||||||
'python': ('http://docs.python.org/2', None), 'numpy':
|
'python': ('http://docs.python.org/2', None),
|
||||||
('http://docs.scipy.org/doc/numpy/', None), 'scipy':
|
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
|
||||||
('http://docs.scipy.org/doc/scipy/reference/', None), 'matplotlib':
|
'scipy': ('http://docs.scipy.org/doc/scipy/reference/', None),
|
||||||
('http://matplotlib.sourceforge.net/', None)
|
'matplotlib': ('http://matplotlib.sourceforge.net/', None),
|
||||||
|
'torch': ('https://pytorch.org/docs/stable/', None)
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add any paths that contain templates here, relative to this directory.
|
# Add any paths that contain templates here, relative to this directory.
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ solve problems in a continuous and nonlinear settings.
|
|||||||
:caption: Package Documentation:
|
:caption: Package Documentation:
|
||||||
|
|
||||||
Installation <_rst/installation>
|
Installation <_rst/installation>
|
||||||
|
API <_rst/code>
|
||||||
Contributing <_rst/contributing>
|
Contributing <_rst/contributing>
|
||||||
License <LICENSE.rst>
|
License <LICENSE.rst>
|
||||||
|
|
||||||
@@ -30,7 +31,7 @@ solve problems in a continuous and nonlinear settings.
|
|||||||
.. ........................................................................................
|
.. ........................................................................................
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:maxdepth: 2
|
:maxdepth: 1
|
||||||
:numbered:
|
:numbered:
|
||||||
:caption: Tutorials:
|
:caption: Tutorials:
|
||||||
|
|
||||||
|
|||||||
@@ -1,64 +1,119 @@
|
|||||||
|
""" Module for LabelTensor """
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
class LabelTensor():
|
|
||||||
|
|
||||||
def __init__(self, x, labels):
|
class LabelTensor(torch.Tensor):
|
||||||
|
"""Torch tensor with a label for any column."""
|
||||||
|
|
||||||
if len(labels) != x.shape[1]:
|
|
||||||
print(len(labels), x.shape[1])
|
|
||||||
raise ValueError
|
|
||||||
self.__labels = labels
|
|
||||||
self.tensor = x
|
|
||||||
|
|
||||||
def __getitem__(self, key):
|
|
||||||
if isinstance(key, (tuple, list)):
|
|
||||||
indeces = [self.labels.index(k) for k in key]
|
|
||||||
return LabelTensor(self.tensor[:, indeces], [self.labels[idx] for idx in indeces])
|
|
||||||
if key in self.labels:
|
|
||||||
return self.tensor[:, self.labels.index(key)]
|
|
||||||
else:
|
|
||||||
return self.tensor.__getitem__(key)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return self.tensor
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return '{}\n {}\n'.format(self.labels, self.tensor)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def shape(self):
|
|
||||||
return self.tensor.shape
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dtype(self):
|
|
||||||
return self.tensor.dtype
|
|
||||||
|
|
||||||
@property
|
|
||||||
def device(self):
|
|
||||||
return self.tensor.device
|
|
||||||
|
|
||||||
@property
|
|
||||||
def labels(self):
|
|
||||||
return self.__labels
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def hstack(labeltensor_list):
|
def __new__(cls, x, labels, *args, **kwargs):
|
||||||
concatenated_tensor = torch.cat([lt.tensor for lt in labeltensor_list], axis=1)
|
return super().__new__(cls, x, *args, **kwargs)
|
||||||
concatenated_label = sum([lt.labels for lt in labeltensor_list], [])
|
|
||||||
return LabelTensor(concatenated_tensor, concatenated_label)
|
|
||||||
|
|
||||||
|
def __init__(self, x, labels):
|
||||||
|
'''
|
||||||
|
Construct a `LabelTensor` by passing a tensor and a list of column
|
||||||
|
labels. Such labels uniquely identify the columns of the tensor,
|
||||||
|
allowing for an easier manipulation.
|
||||||
|
|
||||||
|
:param torch.Tensor x: the data tensor.
|
||||||
|
:param labels: the labels of the columns.
|
||||||
|
:type labels: str or iterable(str)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
:Example:
|
||||||
import numpy as np
|
>>> from pina import LabelTensor
|
||||||
a = np.random.uniform(size=(20, 3))
|
>>> tensor = LabelTensor(torch.rand((2000, 3)), ['a', 'b', 'c'])
|
||||||
a = np.random.uniform(size=(20, 3))
|
>>> tensor
|
||||||
p = torch.from_numpy(a)
|
tensor([[6.7116e-02, 4.8892e-01, 8.9452e-01],
|
||||||
t = LabelTensor(p, labels=['u', 'p', 't'])
|
[9.2392e-01, 8.2065e-01, 4.1986e-04],
|
||||||
print(t)
|
[8.9266e-01, 5.5446e-01, 6.3500e-01],
|
||||||
print(t['u'])
|
...,
|
||||||
t *= 2
|
[5.8194e-01, 9.4268e-01, 4.1841e-01],
|
||||||
print(t['u'])
|
[1.0246e-01, 9.5179e-01, 3.7043e-02],
|
||||||
print(t[:, 0])
|
[9.6150e-01, 8.0656e-01, 8.3824e-01]])
|
||||||
|
>>> tensor.extract('a')
|
||||||
|
tensor([[0.0671],
|
||||||
|
[0.9239],
|
||||||
|
[0.8927],
|
||||||
|
...,
|
||||||
|
[0.5819],
|
||||||
|
[0.1025],
|
||||||
|
[0.9615]])
|
||||||
|
>>> tensor.extract(['a', 'b'])
|
||||||
|
tensor([[0.0671, 0.4889],
|
||||||
|
[0.9239, 0.8207],
|
||||||
|
[0.8927, 0.5545],
|
||||||
|
...,
|
||||||
|
[0.5819, 0.9427],
|
||||||
|
[0.1025, 0.9518],
|
||||||
|
[0.9615, 0.8066]])
|
||||||
|
>>> tensor.extract(['b', 'a'])
|
||||||
|
tensor([[0.4889, 0.0671],
|
||||||
|
[0.8207, 0.9239],
|
||||||
|
[0.5545, 0.8927],
|
||||||
|
...,
|
||||||
|
[0.9427, 0.5819],
|
||||||
|
[0.9518, 0.1025],
|
||||||
|
[0.8066, 0.9615]])
|
||||||
|
'''
|
||||||
|
|
||||||
|
if isinstance(labels, str):
|
||||||
|
labels = [labels]
|
||||||
|
|
||||||
|
if len(labels) != x.shape[1]:
|
||||||
|
raise ValueError(
|
||||||
|
'the tensor has not the same number of columns of '
|
||||||
|
'the passed labels.'
|
||||||
|
)
|
||||||
|
self.labels = labels
|
||||||
|
|
||||||
|
def clone(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Clone the LabelTensor. For more details, see
|
||||||
|
:meth:`torch.Tensor.clone`.
|
||||||
|
|
||||||
|
:return: a copy of the tensor
|
||||||
|
:rtype: LabelTensor
|
||||||
|
"""
|
||||||
|
return LabelTensor(super().clone(*args, **kwargs), self.labels)
|
||||||
|
|
||||||
|
def to(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Performs Tensor dtype and/or device conversion. For more details, see
|
||||||
|
:meth:`torch.Tensor.to`.
|
||||||
|
"""
|
||||||
|
new_obj = LabelTensor([], self.labels)
|
||||||
|
tempTensor = super().to(*args, **kwargs)
|
||||||
|
new_obj.data = tempTensor.data
|
||||||
|
new_obj.requires_grad = tempTensor.requires_grad
|
||||||
|
return new_obj
|
||||||
|
|
||||||
|
def extract(self, label_to_extract):
|
||||||
|
"""
|
||||||
|
Extract the subset of the original tensor by returning all the columns
|
||||||
|
corresponding to the passed `label_to_extract`.
|
||||||
|
|
||||||
|
:param label_to_extract: the label(s) to extract.
|
||||||
|
:type label_to_extract: str or iterable(str)
|
||||||
|
:raises TypeError: labels are not str
|
||||||
|
:raises ValueError: label to extract is not in the labels list
|
||||||
|
"""
|
||||||
|
|
||||||
|
if isinstance(label_to_extract, str):
|
||||||
|
label_to_extract = [label_to_extract]
|
||||||
|
elif isinstance(label_to_extract, (tuple, list)): # TODO
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
'`label_to_extract` should be a str, or a str iterator')
|
||||||
|
|
||||||
|
try:
|
||||||
|
indeces = [self.labels.index(f) for f in label_to_extract]
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError('`label_to_extract` not in the labels list')
|
||||||
|
|
||||||
|
extracted_tensor = LabelTensor(
|
||||||
|
self[:, indeces],
|
||||||
|
[self.labels[idx] for idx in indeces]
|
||||||
|
)
|
||||||
|
|
||||||
|
return extracted_tensor
|
||||||
|
|||||||
61
tests/test_label_tensor.py
Normal file
61
tests/test_label_tensor.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from pina import LabelTensor
|
||||||
|
|
||||||
|
data = torch.rand((20, 3))
|
||||||
|
labels = ['a', 'b', 'c']
|
||||||
|
|
||||||
|
|
||||||
|
def test_constructor():
|
||||||
|
LabelTensor(data, labels)
|
||||||
|
|
||||||
|
|
||||||
|
def test_wrong_constructor():
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
LabelTensor(data, ['a', 'b'])
|
||||||
|
|
||||||
|
|
||||||
|
def test_labels():
|
||||||
|
tensor = LabelTensor(data, labels)
|
||||||
|
assert isinstance(tensor, torch.Tensor)
|
||||||
|
assert tensor.labels == labels
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract():
|
||||||
|
label_to_extract = ['a', 'c']
|
||||||
|
tensor = LabelTensor(data, labels)
|
||||||
|
new = tensor.extract(label_to_extract)
|
||||||
|
assert new.labels == label_to_extract
|
||||||
|
assert new.shape[1] == len(label_to_extract)
|
||||||
|
assert torch.all(torch.isclose(data[:, 0::2], new))
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_onelabel():
|
||||||
|
label_to_extract = ['a']
|
||||||
|
tensor = LabelTensor(data, labels)
|
||||||
|
new = tensor.extract(label_to_extract)
|
||||||
|
assert new.ndim == 2
|
||||||
|
assert new.labels == label_to_extract
|
||||||
|
assert new.shape[1] == len(label_to_extract)
|
||||||
|
assert torch.all(torch.isclose(data[:, 0].reshape(-1, 1), new))
|
||||||
|
|
||||||
|
|
||||||
|
def test_wrong_extract():
|
||||||
|
label_to_extract = ['a', 'cc']
|
||||||
|
tensor = LabelTensor(data, labels)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
tensor.extract(label_to_extract)
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_order():
|
||||||
|
label_to_extract = ['c', 'a']
|
||||||
|
tensor = LabelTensor(data, labels)
|
||||||
|
new = tensor.extract(label_to_extract)
|
||||||
|
expected = torch.cat(
|
||||||
|
(data[:, 2].reshape(-1, 1), data[:, 0].reshape(-1, 1)),
|
||||||
|
dim=1)
|
||||||
|
print(expected)
|
||||||
|
assert new.labels == label_to_extract
|
||||||
|
assert new.shape[1] == len(label_to_extract)
|
||||||
|
assert torch.all(torch.isclose(expected, new))
|
||||||
Reference in New Issue
Block a user