Documentation for v0.1 version (#199)
* Adding Equations, solving typos * improve _code.rst * the team rst and restuctore index.rst * fixing errors --------- Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
This commit is contained in:
committed by
Nicola Demo
parent
3f9305d475
commit
8b7b61b3bd
@@ -17,9 +17,9 @@ class LabelTensor(torch.Tensor):
|
||||
labels. Such labels uniquely identify the columns of the tensor,
|
||||
allowing for an easier manipulation.
|
||||
|
||||
:param torch.Tensor x: the data tensor.
|
||||
:param labels: the labels of the columns.
|
||||
:type labels: str or iterable(str)
|
||||
:param torch.Tensor x: The data tensor.
|
||||
:param labels: The labels of the columns.
|
||||
:type labels: str | list(str) | tuple(str)
|
||||
|
||||
:Example:
|
||||
>>> from pina import LabelTensor
|
||||
@@ -72,10 +72,8 @@ class LabelTensor(torch.Tensor):
|
||||
labels = [labels]
|
||||
|
||||
if len(labels) != x.shape[-1]:
|
||||
raise ValueError(
|
||||
'the tensor has not the same number of columns of '
|
||||
'the passed labels.'
|
||||
)
|
||||
raise ValueError('the tensor has not the same number of columns of '
|
||||
'the passed labels.')
|
||||
self._labels = labels
|
||||
|
||||
@property
|
||||
@@ -90,11 +88,10 @@ class LabelTensor(torch.Tensor):
|
||||
@labels.setter
|
||||
def labels(self, labels):
|
||||
if len(labels) != self.shape[self.ndim - 1]: # small check
|
||||
raise ValueError(
|
||||
'the tensor has not the same number of columns of '
|
||||
'the passed labels.')
|
||||
raise ValueError('The tensor has not the same number of columns of '
|
||||
'the passed labels.')
|
||||
|
||||
self._labels = labels # assign the label
|
||||
self._labels = labels # assign the label
|
||||
|
||||
@staticmethod
|
||||
def vstack(label_tensors):
|
||||
@@ -123,7 +120,7 @@ class LabelTensor(torch.Tensor):
|
||||
Clone the LabelTensor. For more details, see
|
||||
:meth:`torch.Tensor.clone`.
|
||||
|
||||
:return: a copy of the tensor
|
||||
:return: A copy of the tensor.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
# # used before merging
|
||||
@@ -173,12 +170,12 @@ class LabelTensor(torch.Tensor):
|
||||
def extract(self, label_to_extract):
|
||||
"""
|
||||
Extract the subset of the original tensor by returning all the columns
|
||||
corresponding to the passed `label_to_extract`.
|
||||
corresponding to the passed ``label_to_extract``.
|
||||
|
||||
:param label_to_extract: the label(s) to extract.
|
||||
:type label_to_extract: str or iterable(str)
|
||||
:raises TypeError: labels are not str
|
||||
:raises ValueError: label to extract is not in the labels list
|
||||
:param label_to_extract: The label(s) to extract.
|
||||
:type label_to_extract: str | list(str) | tuple(str)
|
||||
:raises TypeError: Labels are not ``str``.
|
||||
:raises ValueError: Label to extract is not in the labels ``list``.
|
||||
"""
|
||||
|
||||
if isinstance(label_to_extract, str):
|
||||
@@ -211,7 +208,7 @@ class LabelTensor(torch.Tensor):
|
||||
return detached
|
||||
|
||||
|
||||
def requires_grad_(self, mode = True) -> Tensor:
|
||||
def requires_grad_(self, mode = True):
|
||||
lt = super().requires_grad_(mode)
|
||||
lt.labels = self.labels
|
||||
return lt
|
||||
@@ -220,9 +217,9 @@ class LabelTensor(torch.Tensor):
|
||||
"""
|
||||
Return a copy of the merged tensors.
|
||||
|
||||
:param LabelTensor lt: the tensor to merge.
|
||||
:param LabelTensor lt: The tensor to merge.
|
||||
:param str mode: {'std', 'first', 'cross'}
|
||||
:return: the merged tensors
|
||||
:return: The merged tensors.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
if set(self.labels).intersection(lt.labels):
|
||||
@@ -239,12 +236,9 @@ class LabelTensor(torch.Tensor):
|
||||
n1 = tensor1.shape[0]
|
||||
n2 = tensor2.shape[0]
|
||||
|
||||
tensor1 = LabelTensor(
|
||||
tensor1.repeat(n2, 1),
|
||||
labels=tensor1.labels)
|
||||
tensor2 = LabelTensor(
|
||||
tensor2.repeat_interleave(n1, dim=0),
|
||||
labels=tensor2.labels)
|
||||
tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels)
|
||||
tensor2 = LabelTensor(tensor2.repeat_interleave(n1, dim=0),
|
||||
labels=tensor2.labels)
|
||||
new_tensor = torch.cat((tensor1, tensor2), dim=1)
|
||||
|
||||
new_tensor = new_tensor.as_subclass(LabelTensor)
|
||||
@@ -290,7 +284,7 @@ class LabelTensor(torch.Tensor):
|
||||
|
||||
def __len__(self) -> int:
|
||||
return super().__len__()
|
||||
|
||||
|
||||
def __str__(self):
|
||||
if hasattr(self, 'labels'):
|
||||
s = f'labels({str(self.labels)})\n'
|
||||
|
||||
Reference in New Issue
Block a user