diff --git a/docs/source/_rst/code.rst b/docs/source/_rst/code.rst new file mode 100644 index 0000000..0916b80 --- /dev/null +++ b/docs/source/_rst/code.rst @@ -0,0 +1,7 @@ +Code Documentation +================== + +.. toctree:: + :maxdepth: 3 + + LabelTensor diff --git a/docs/source/_rst/label_tensor.rst b/docs/source/_rst/label_tensor.rst new file mode 100644 index 0000000..976e839 --- /dev/null +++ b/docs/source/_rst/label_tensor.rst @@ -0,0 +1,12 @@ +LabelTensor +=========== +.. currentmodule:: pina.label_tensor + +.. automodule:: pina.label_tensor + +.. autoclass:: LabelTensor + :members: + :private-members: + :undoc-members: + :show-inheritance: + :noindex: diff --git a/docs/source/conf.py b/docs/source/conf.py index 0ca6208..b2dbc70 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -47,12 +47,14 @@ extensions = [ 'sphinx.ext.ifconfig', 'sphinx.ext.mathjax', ] +autosummary_generate = True intersphinx_mapping = { - 'python': ('http://docs.python.org/2', None), 'numpy': - ('http://docs.scipy.org/doc/numpy/', None), 'scipy': - ('http://docs.scipy.org/doc/scipy/reference/', None), 'matplotlib': - ('http://matplotlib.sourceforge.net/', None) + 'python': ('http://docs.python.org/2', None), + 'numpy': ('http://docs.scipy.org/doc/numpy/', None), + 'scipy': ('http://docs.scipy.org/doc/scipy/reference/', None), + 'matplotlib': ('http://matplotlib.sourceforge.net/', None), + 'torch': ('https://pytorch.org/docs/stable/', None) } # Add any paths that contain templates here, relative to this directory. diff --git a/docs/source/index.rst b/docs/source/index.rst index 92a64cb..d2c7b6b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -22,6 +22,7 @@ solve problems in a continuous and nonlinear settings. :caption: Package Documentation: Installation <_rst/installation> + API <_rst/code> Contributing <_rst/contributing> License @@ -30,7 +31,7 @@ solve problems in a continuous and nonlinear settings. .. 
class LabelTensor(torch.Tensor):
    """Torch tensor with a label for any column."""

    @staticmethod
    def __new__(cls, x, labels, *args, **kwargs):
        # `labels` is handled by `__init__`; only the data (and any extra
        # arguments) is forwarded to the `torch.Tensor` constructor.
        return super().__new__(cls, x, *args, **kwargs)

    def __init__(self, x, labels):
        """
        Construct a `LabelTensor` by passing a tensor and a list of column
        labels. Such labels uniquely identify the columns of the tensor,
        allowing for an easier manipulation.

        :param torch.Tensor x: the data tensor.
        :param labels: the labels of the columns.
        :type labels: str or iterable(str)
        :raises ValueError: the number of labels differs from the number of
            columns (``x.shape[1]``) of the tensor.

        :Example:
            >>> import torch
            >>> from pina import LabelTensor
            >>> tensor = LabelTensor(torch.rand((2000, 3)), ['a', 'b', 'c'])
            >>> tensor
            tensor([[6.7116e-02, 4.8892e-01, 8.9452e-01],
                    [9.2392e-01, 8.2065e-01, 4.1986e-04],
                    [8.9266e-01, 5.5446e-01, 6.3500e-01],
                    ...,
                    [5.8194e-01, 9.4268e-01, 4.1841e-01],
                    [1.0246e-01, 9.5179e-01, 3.7043e-02],
                    [9.6150e-01, 8.0656e-01, 8.3824e-01]])
            >>> tensor.extract('a')
            tensor([[0.0671],
                    [0.9239],
                    [0.8927],
                    ...,
                    [0.5819],
                    [0.1025],
                    [0.9615]])
            >>> tensor.extract(['a', 'b'])
            tensor([[0.0671, 0.4889],
                    [0.9239, 0.8207],
                    [0.8927, 0.5545],
                    ...,
                    [0.5819, 0.9427],
                    [0.1025, 0.9518],
                    [0.9615, 0.8066]])
            >>> tensor.extract(['b', 'a'])
            tensor([[0.4889, 0.0671],
                    [0.8207, 0.9239],
                    [0.5545, 0.8927],
                    ...,
                    [0.9427, 0.5819],
                    [0.9518, 0.1025],
                    [0.8066, 0.9615]])
        """
        # A single label may be given as a bare string.
        if isinstance(labels, str):
            labels = [labels]

        if len(labels) != x.shape[1]:
            raise ValueError(
                'the tensor has not the same number of columns of '
                'the passed labels.'
            )
        self.labels = labels

    def clone(self, *args, **kwargs):
        """
        Clone the LabelTensor. For more details, see
        :meth:`torch.Tensor.clone`.

        :return: a copy of the tensor, carrying the same labels
        :rtype: LabelTensor
        """
        return LabelTensor(super().clone(*args, **kwargs), self.labels)

    def to(self, *args, **kwargs):
        """
        Performs Tensor dtype and/or device conversion. For more details, see
        :meth:`torch.Tensor.to`.

        :return: a new tensor holding the converted data and the same labels
        :rtype: LabelTensor
        """
        converted = super().to(*args, **kwargs)

        # The placeholder must pass the constructor's column check, so it is
        # an empty tensor with one column per label. A plain `[]` has no
        # `.shape` and made every call to this method fail.
        new_obj = LabelTensor(
            torch.empty((0, len(self.labels))), self.labels)

        # NOTE(review): handing the values over through `.data` appears to
        # drop the source's autograd history -- confirm this is intended.
        new_obj.data = converted.data
        new_obj.requires_grad = converted.requires_grad
        return new_obj

    def extract(self, label_to_extract):
        """
        Extract the subset of the original tensor by returning all the columns
        corresponding to the passed `label_to_extract`, in the order in which
        the labels are passed.

        :param label_to_extract: the label(s) to extract.
        :type label_to_extract: str or iterable(str)
        :return: the selected columns, labelled accordingly
        :rtype: LabelTensor
        :raises TypeError: labels are not str
        :raises ValueError: label to extract is not in the labels list
        """
        if isinstance(label_to_extract, str):
            label_to_extract = [label_to_extract]
        elif not isinstance(label_to_extract, (tuple, list)):
            # TODO: only tuple and list are accepted as iterables for now.
            raise TypeError(
                '`label_to_extract` should be a str, or a str iterator')

        try:
            indices = [self.labels.index(name) for name in label_to_extract]
        except ValueError as err:
            # Chain the lookup failure so the original cause stays visible.
            raise ValueError(
                '`label_to_extract` not in the labels list') from err

        return LabelTensor(
            self[:, indices],
            [self.labels[idx] for idx in indices]
        )
label_to_extract = ['a'] + tensor = LabelTensor(data, labels) + new = tensor.extract(label_to_extract) + assert new.ndim == 2 + assert new.labels == label_to_extract + assert new.shape[1] == len(label_to_extract) + assert torch.all(torch.isclose(data[:, 0].reshape(-1, 1), new)) + + +def test_wrong_extract(): + label_to_extract = ['a', 'cc'] + tensor = LabelTensor(data, labels) + with pytest.raises(ValueError): + tensor.extract(label_to_extract) + + +def test_extract_order(): + label_to_extract = ['c', 'a'] + tensor = LabelTensor(data, labels) + new = tensor.extract(label_to_extract) + expected = torch.cat( + (data[:, 2].reshape(-1, 1), data[:, 0].reshape(-1, 1)), + dim=1) + print(expected) + assert new.labels == label_to_extract + assert new.shape[1] == len(label_to_extract) + assert torch.all(torch.isclose(expected, new))