Doc LabelTensor
committed by Nicola Demo
parent 635e3b3a75
commit 9e40b58339
@@ -6,10 +6,23 @@ from torch import Tensor
 class LabelTensor(torch.Tensor):
-    """Torch tensor with a label for any column."""
+    """
+    Extension of the :class:`torch.Tensor` class that includes labels for
+    each dimension.
+    """

     @staticmethod
     def __new__(cls, x, labels, *args, **kwargs):
+        """
+        Create a new instance of the :class:`LabelTensor` class.
+
+        :param torch.Tensor x: :class:`torch.Tensor` instance to be cast to a
+            :class:`LabelTensor`.
+        :param labels: Labels to assign to the tensor.
+        :type labels: str | list(str) | dict
+        :return: The instance of the :class:`LabelTensor` class.
+        :rtype: LabelTensor
+        """
         if isinstance(x, LabelTensor):
             return x
@@ -18,47 +31,47 @@ class LabelTensor(torch.Tensor):
     @property
     def tensor(self):
         """
-        Give the tensor part of the LabelTensor.
+        Give the tensor part of the :class:`LabelTensor` object.

-        :return: tensor part of the LabelTensor
+        :return: tensor part of the :class:`LabelTensor`.
         :rtype: torch.Tensor
         """
         return self.as_subclass(Tensor)

     def __init__(self, x, labels):
         """
-        Construct a `LabelTensor` by passing a dict of the labels
+        Construct a :class:`LabelTensor` by passing a dict of the labels and a
+        :class:`torch.Tensor`. Internally, the initialization method will store
+        the labels and check their compatibility with the tensor shape.

         :Example:
+            >>> import torch
             >>> from pina import LabelTensor
             >>> tensor = LabelTensor(
-            >>>     torch.rand((2000, 3)),
-            >>>     {1: {"name": "space"['a', 'b', 'c'])
+            ...     torch.rand((2000, 3)),
+            ...     {1: {"name": "space", "dof": ["a", "b", "c"]}})
+            >>> tensor = LabelTensor(
+            ...     torch.rand((2000, 3)),
+            ...     ["a", "b", "c"])
         """
+        # Avoid unused argument warning. x is not used in the constructor
+        # of the parent class.
+        # pylint: disable=unused-argument
         super().__init__()
         if labels is not None:
             self.labels = labels
         else:
             self._labels = {}

-    @property
-    def labels(self):
-        """Property decorator for labels
-
-        :return: labels of self
-        :rtype: list
-        """
-        if self.ndim - 1 in self._labels:
-            return self._labels[self.ndim - 1]["dof"]
-        return None
-
     @property
     def full_labels(self):
-        """Property decorator for labels
+        """
+        Gives the full labels of the tensor, even for the dimensions that are
+        not labeled.

-        :return: labels of self
-        :rtype: list
+        :return: The full labels of the tensor.
+        :rtype: dict
         """
         to_return_dict = {}
         shape_tensor = self.shape
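
The constructor and the label properties documented above can be exercised as in the sketch below. It is illustrative rather than taken from the library, and it assumes the PINA LabelTensor API behaves exactly as the docstrings describe.

import torch
from pina import LabelTensor

# Label the last dimension of a (100, 3) tensor with three dof names.
points = LabelTensor(torch.rand(100, 3), ["x", "y", "z"])

print(points.labels)        # dof of the last dimension: ["x", "y", "z"]
print(points.full_labels)   # dict with an entry for every dimension, labeled or not
print(type(points.tensor))  # the plain torch.Tensor part of the instance
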
@@ -71,21 +84,40 @@ class LabelTensor(torch.Tensor):
     @property
     def stored_labels(self):
-        """Property decorator for labels
+        """
+        Gives the labels stored inside the instance.

-        :return: labels of self
-        :rtype: list
+        :return: The labels stored inside the instance.
+        :rtype: dict
         """
         return self._labels

+    @property
+    def labels(self):
+        """
+        Give the labels of the last dimension of the instance.
+
+        :return: labels of last dimension
+        :rtype: list
+        """
+        if self.ndim - 1 in self._labels:
+            return self._labels[self.ndim - 1]["dof"]
+        return None
+
     @labels.setter
     def labels(self, labels):
-        """ "
-        Set properly the parameter _labels
+        """
+        Set the parameter ``_labels`` by checking the type of the input labels
+        and handling it accordingly. The following types are accepted:
+
+        - **list**: The list of labels is assigned to the last dimension.
+        - **dict**: The dictionary of labels is assigned to the tensor.
+        - **str**: The string is assigned to the last dimension.
+
         :param labels: Labels to assign to the class variable _labels.
-        :type: labels: str | list(str) | dict
+        :type labels: str | list(str) | dict
         """
         if not hasattr(self, "_labels"):
             self._labels = {}
         if isinstance(labels, dict):
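
A brief sketch of the three assignment forms accepted by the labels setter above (list, dict, str); the shapes and names are made up for illustration and assume the behavior documented in the docstring.

import torch
from pina import LabelTensor

lt = LabelTensor(torch.rand(10, 2), ["u", "v"])

# list: relabel the dof of the last dimension
lt.labels = ["p", "q"]

# dict: label a chosen dimension with a name and its dof
lt.labels = {1: {"name": "output", "dof": ["p", "q"]}}

# str: a single label, assigned to the last dimension (which must have one dof)
scalar = LabelTensor(torch.rand(10, 1), "p")
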
@@ -100,14 +132,14 @@ class LabelTensor(torch.Tensor):
     def _init_labels_from_dict(self, labels: dict):
         """
-        Update the internal label representation according to the values
+        Store the internal label representation according to the values
         passed as input.

-        :param labels: The label(s) to update.
-        :type labels: dict
+        :param dict labels: The label(s) to update.
         :raises ValueError: If the dof list contains duplicates or the number of
             dof does not match the tensor shape.
         """
         tensor_shape = self.shape

         def validate_dof(dof_list, dim_size: int):
@@ -151,12 +183,13 @@ class LabelTensor(torch.Tensor):
     def _init_labels_from_list(self, labels):
         """
-        Given a list of dof, this method update the internal label
-        representation
+        Given a ``list`` of dof, this method updates the internal label
+        representation by assigning the dof to the last dimension.

         :param labels: The label(s) to update.
         :type labels: list
         """
         # Create a dict with labels
         last_dim_labels = {
             self.ndim - 1: {"dof": labels, "name": self.ndim - 1}
@@ -165,12 +198,19 @@ class LabelTensor(torch.Tensor):
     def extract(self, labels_to_extract):
         """
-        Extract the subset of the original tensor by returning all the columns
-        corresponding to the passed ``label_to_extract``.
+        Extract the subset of the original tensor by returning all the
+        positions corresponding to the passed ``labels_to_extract``.

-        :param labels_to_extract: The label(s) to extract.
-        :type labels_to_extract: str | list(str) | tuple(str)
-        :raises TypeError: Labels are not ``str``.
+        :param labels_to_extract: The label(s) to extract. If a single label or
+            a list of labels is passed, the last dimension is considered.
+            If a dictionary is passed, the keys are the dimension names and the
+            values are the labels to extract.
+        :type labels_to_extract: str | list(str) | tuple(str) | dict
+        :return: The extracted tensor with the updated labels.
+        :rtype: LabelTensor
+
+        :raises TypeError: Labels are not ``str``, ``list(str)`` or ``dict``
+            properly set.
         :raises ValueError: Label to extract is not in the labels ``list``.
         """
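
As a usage sketch of extract, assuming the documented behavior that plain labels address the last dimension:

import torch
from pina import LabelTensor

pts = LabelTensor(torch.rand(50, 3), ["x", "y", "t"])

# A list of labels addresses the last dimension.
space = pts.extract(["x", "y"])   # expected shape (50, 2), labels ["x", "y"]

# A single label works the same way.
time = pts.extract("t")           # expected shape (50, 1), labels ["t"]
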
@@ -231,8 +271,12 @@ class LabelTensor(torch.Tensor):
     def __str__(self):
         """
-        returns a string with the representation of the class
+        The string representation of the :class:`LabelTensor`.
+
+        :return: String representation of the :class:`LabelTensor` instance.
+        :rtype: str
         """
         s = ""
         for key, value in self._labels.items():
             s += f"{key}: {value}\n"
@@ -243,18 +287,20 @@ class LabelTensor(torch.Tensor):
     @staticmethod
     def cat(tensors, dim=0):
         """
-        Stack a list of tensors. For example, given a tensor `a` of shape
-        `(n,m,dof)` and a tensor `b` of dimension `(n',m,dof)`
-        the resulting tensor is of shape `(n+n',m,dof)`
+        Concatenate a list of tensors along a specified dimension. For more
+        details, see :func:`torch.cat`.

-        :param tensors: tensors to concatenate
-        :type tensors: list of LabelTensor
-        :param dim: dimensions on which you want to perform the operation
-            (default is 0)
-        :type dim: int
+        :param list(LabelTensor) tensors: :class:`LabelTensor` instances to
+            concatenate
+        :param int dim: dimension along which to perform the operation
+            (default is 0)
+        :return: A new :class:`LabelTensor` instance obtained by concatenating
+            the input instances, with the updated labels.
         :rtype: LabelTensor
-        :raises ValueError: either number dof or dimensions names differ
+        :raises ValueError: If either the number of dof or the dimension names
+            differ.
         """
         if not tensors:
             return []  # Handle empty list
         if len(tensors) == 1:
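
A hypothetical example of the cat behavior documented above, concatenating two tensors that share labels along dim 0:

import torch
from pina import LabelTensor

a = LabelTensor(torch.rand(10, 3), ["x", "y", "z"])
b = LabelTensor(torch.rand(20, 3), ["x", "y", "z"])

# (10, 3) and (20, 3) concatenated along dim 0 give (30, 3);
# the common labels are carried over to the result.
ab = LabelTensor.cat([a, b], dim=0)
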
@@ -295,15 +341,16 @@ class LabelTensor(torch.Tensor):
     @staticmethod
     def stack(tensors):
         """
-        Stacks a list of tensors along a new dimension.
+        Stacks a list of tensors along a new dimension. For more details, see
+        :func:`torch.stack`.

-        :param tensors: A list of tensors to stack. All tensors must have the
-            same shape.
-        :type tensors: list of LabelTensor
-        :return: A new tensor obtained by stacking the input tensors,
-            with the updated labels.
+        :param list(LabelTensor) tensors: A list of tensors to stack.
+            All tensors must have the same shape.
+        :return: A new :class:`LabelTensor` instance obtained by stacking the
+            input tensors, with the updated labels.
         :rtype: LabelTensor
         """
         # Perform stacking in torch
         new_tensor = torch.stack(tensors)
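
And a matching sketch for stack, which adds a new leading dimension (again assuming the API works as documented):

import torch
from pina import LabelTensor

a = LabelTensor(torch.rand(10, 3), ["x", "y", "z"])
b = LabelTensor(torch.rand(10, 3), ["x", "y", "z"])

# Two (10, 3) tensors stacked along a new dimension: expected shape (2, 10, 3).
stacked = LabelTensor.stack([a, b])
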
@@ -315,17 +362,17 @@ class LabelTensor(torch.Tensor):
     def requires_grad_(self, mode=True):
         """
-        Override the requires_grad_ method to update the labels in the new
-        tensor.
+        Override the requires_grad_ method to handle the labels in the new
+        tensor. For more details, see :meth:`torch.Tensor.requires_grad_`.

-        :param mode: A boolean value indicating whether the tensor should track
-            gradients. If `True`, the tensor will track gradients; if
-            `False`, it will not.
-        :type mode: bool, optional (default is `True`)
-        :return: The tensor itself with the updated `requires_grad` state and
-            retained labels.
+        :param bool mode: A boolean value indicating whether the tensor should
+            track gradients. If ``True``, the tensor will track gradients;
+            if ``False``, it will not.
+        :return: The :class:`LabelTensor` itself with the updated
+            ``requires_grad`` state and retained labels.
         :rtype: LabelTensor
         """
         lt = super().requires_grad_(mode)
         lt._labels = self._labels
         return lt
@@ -333,30 +380,39 @@ class LabelTensor(torch.Tensor):
     @property
     def dtype(self):
         """
-        Give the dtype of the tensor.
+        Give the ``dtype`` of the tensor. For more details, see
+        :class:`torch.dtype`.

         :return: dtype of the tensor
         :rtype: torch.dtype
         """
         return super().dtype

     def to(self, *args, **kwargs):
         """
         Performs Tensor dtype and/or device conversion. For more details, see
         :meth:`torch.Tensor.to`.

+        :return: A new :class:`LabelTensor` instance with the updated dtype
+            and/or device and retained labels.
+        :rtype: LabelTensor
         """
         lt = super().to(*args, **kwargs)
         lt._labels = self._labels
         return lt

     def clone(self, *args, **kwargs):
         """
-        Clone the LabelTensor. For more details, see
+        Clone the :class:`LabelTensor`. For more details, see
         :meth:`torch.Tensor.clone`.

-        :return: A copy of the tensor.
+        :return: A new :class:`LabelTensor` instance with the same data and
+            labels but allocated in a different memory location.
         :rtype: LabelTensor
         """
         out = LabelTensor(
             super().clone(*args, **kwargs), deepcopy(self._labels)
         )
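
A short sketch showing that, per the docstrings above, dtype/device conversion and cloning both return a LabelTensor with the labels retained; the dtype chosen here is arbitrary.

import torch
from pina import LabelTensor

lt = LabelTensor(torch.rand(8, 2), ["u", "v"])

converted = lt.to(dtype=torch.float64)  # labels re-attached to the converted tensor
duplicate = lt.clone()                  # same data and labels, new memory

# Both results are expected to keep the original labels.
print(converted.labels, duplicate.labels)
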
@@ -366,21 +422,23 @@ class LabelTensor(torch.Tensor):
         """
         Appends a given tensor to the current tensor along the last dimension.

-        This method allows for two types of appending operations:
-        1. **Standard append** ("std"): Concatenates the tensors along the
-            last dimension.
-        2. **Cross append** ("cross"): Repeats the current tensor and the new
-            tensor in a cross-product manner, then concatenates them.
+        This method supports two types of appending operations:
+        1. **Standard append** ("std"): Concatenates the input tensor with the
+            current tensor along the last dimension.
+        2. **Cross append** ("cross"): Repeats the current tensor and the input
+            tensor in a cross-product fashion, then concatenates the result
+            along the last dimension.

-        :param LabelTensor tensor: The tensor to append.
+        :param tensor: The tensor to append to the current tensor.
+        :type tensor: LabelTensor
         :param mode: The append mode to use. Defaults to "std".
         :type mode: str, optional
-        :return: The new tensor obtained by appending the input tensor
-            (either 'std' or 'cross').
+        :return: A new :class:`LabelTensor` obtained by appending the input
+            tensor.
         :rtype: LabelTensor

         :raises ValueError: If the mode is not "std" or "cross".
         """
         if mode == "std":
             # Call cat on last dimension
             new_label_tensor = LabelTensor.cat(
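
The two append modes can be sketched as follows (illustrative shapes, assuming the documented semantics):

import torch
from pina import LabelTensor

a = LabelTensor(torch.rand(10, 2), ["x", "y"])
b = LabelTensor(torch.rand(10, 1), ["t"])

# "std": concatenate along the last dimension -> expected shape (10, 3),
# labels ["x", "y", "t"].
std = a.append(b, mode="std")

# "cross": repeat the two tensors in a cross-product fashion before
# concatenating along the last dimension.
crossed = a.append(b, mode="cross")
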
@@ -406,14 +464,15 @@ class LabelTensor(torch.Tensor):
     @staticmethod
     def vstack(label_tensors):
         """
-        Stack tensors vertically. For more details, see
-        :meth:`torch.vstack`.
+        Stack tensors vertically. For more details, see :func:`torch.vstack`.

-        :param list(LabelTensor) label_tensors: the tensors to stack. They need
-            to have equal labels.
-        :return: the stacked tensor
+        :param list(LabelTensor) label_tensors: The :class:`LabelTensor`
+            instances to stack. They need to have equal labels.
+        :return: A new :class:`LabelTensor` instance obtained by stacking the
+            input tensors vertically.
         :rtype: LabelTensor
         """
         return LabelTensor.cat(label_tensors, dim=0)

     # This method is used to update labels
@@ -421,13 +480,17 @@ class LabelTensor(torch.Tensor):
         self, old_labels, to_update_labels, index, dim, to_update_dim
     ):
         """
-        Update the labels of the tensor by selecting only the labels
-        :param old_labels: labels from which retrieve data
-        :param to_update_labels: labels to update
-        :param index: index of dof to retain
-        :param dim: label index
-        :return:
+        Update the labels of the tensor based on the index (or list of
+        indices).
+
+        :param dict old_labels: Labels from which to retrieve data.
+        :param dict to_update_labels: Labels to update.
+        :param index: Index of dof to retain.
+        :type index: int | slice | list | torch.Tensor
+        :param int dim: The dimension to update.
+
+        :raises ValueError: If the index type is not supported.
         """
         old_dof = old_labels[to_update_dim]["dof"]
         label_name = old_labels[dim]["name"]
         # Handle slicing
@@ -460,16 +523,22 @@ class LabelTensor(torch.Tensor):
     def __getitem__(self, index):
-        """ "
-        Override the __getitem__ method to handle the labels of the tensor.
-        Perform the __getitem__ operation on the tensor and update the labels.
+        """
+        Override the __getitem__ method to handle the labels of the
+        :class:`LabelTensor` instance. It first performs the ``__getitem__``
+        operation on the :class:`torch.Tensor` part of the instance, then
+        updates the labels based on the index.

         :param index: The index used to access the item
-        :type index: Union[int, str, tuple, list]
-        :return: A tensor-like object with updated labels.
+        :type index: int | str | tuple | list | torch.Tensor
+        :return: A new :class:`LabelTensor` instance obtained by the
+            ``__getitem__`` operation on the :class:`torch.Tensor` part of the
+            instance, with the updated labels.
         :rtype: LabelTensor

         :raises KeyError: If an invalid label index is provided.
         :raises IndexError: If an invalid index is accessed in the tensor.
         """
         # Handle string index
         if isinstance(index, str) or (
             isinstance(index, (tuple, list))
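
A quick sketch of label-based and positional indexing as documented above (names are illustrative):

import torch
from pina import LabelTensor

pts = LabelTensor(torch.rand(50, 3), ["x", "y", "t"])

col = pts["x"]        # string index: select the column labeled "x"
sub = pts[[0, 1, 2]]  # list index: plain tensor indexing, labels updated accordingly
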
@@ -516,12 +585,11 @@ class LabelTensor(torch.Tensor):
     def sort_labels(self, dim=None):
         """
-        Sorts the labels along a specified dimension and returns a new tensor
-        with sorted labels.
+        Sort the labels along the specified dimension and apply the same
+        sorting to the :class:`torch.Tensor` part of the instance.

-        :param dim: The dimension along which to sort the labels. If `None`,
-            the last dimension (`ndim - 1`) is used.
-        :type dim: int, optional
+        :param int dim: The dimension along which to sort the labels.
+            If ``None``, the last dimension (``ndim - 1``) is used.
         :return: A new tensor with sorted labels along the specified dimension.
         :rtype: LabelTensor
         """
@@ -543,13 +611,15 @@ class LabelTensor(torch.Tensor):
     def __deepcopy__(self, memo):
         """
-        Creates a deep copy of the object.
+        Creates a deep copy of the object. For more details, see
+        :func:`copy.deepcopy`.

-        :param memo: LabelTensor object to be copied.
-        :type memo: LabelTensor
+        :param dict memo: Dictionary of already-copied objects, as used by the
+            :func:`copy.deepcopy` protocol.
         :return: A deep copy of the original LabelTensor object.
         :rtype: LabelTensor
         """
         cls = self.__class__
         result = cls(deepcopy(self.tensor), deepcopy(self.stored_labels))
         return result
@@ -557,7 +627,7 @@ class LabelTensor(torch.Tensor):
     def permute(self, *dims):
         """
         Permutes the dimensions of the tensor and the associated labels
-        accordingly.
+        accordingly. For more details, see :meth:`torch.Tensor.permute`.

         :param dims: The dimensions to permute the tensor to.
         :type dims: tuple, list
@@ -579,11 +649,12 @@ class LabelTensor(torch.Tensor):
     def detach(self):
         """
         Detaches the tensor from the computation graph and retains the stored
-        labels.
+        labels. For more details, see :meth:`torch.Tensor.detach`.

         :return: A new tensor detached from the computation graph.
         :rtype: LabelTensor
         """
         lt = super().detach()

         # Copy the labels to the new tensor only if present
@@ -594,14 +665,15 @@ class LabelTensor(torch.Tensor):
     @staticmethod
     def summation(tensors):
         """
-        Computes the summation of a list of tensors.
+        Computes the summation of a list of :class:`LabelTensor` instances.

-        :param tensors: A list of tensors to sum. All tensors must have the same
-            shape and labels.
-        :type tensors: list of LabelTensor
+        :param list(LabelTensor) tensors: A list of tensors to sum. All tensors
+            must have the same shape and labels.
         :return: A new `LabelTensor` containing the element-wise sum of the
             input tensors.
         :rtype: LabelTensor

         :raises ValueError: If the input `tensors` list is empty.
         :raises RuntimeError: If the tensors have different shapes and/or
             mismatched labels.
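
Finally, a sketch of summation for tensors that share shape and labels, as the docstring requires:

import torch
from pina import LabelTensor

a = LabelTensor(torch.rand(10, 2), ["u", "v"])
b = LabelTensor(torch.rand(10, 2), ["u", "v"])

# Element-wise sum of same-shape, same-label tensors.
total = LabelTensor.summation([a, b])
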
@@ -637,12 +709,14 @@ class LabelTensor(torch.Tensor):
     def reshape(self, *shape):
         """
         Override the reshape method to update the labels of the tensor.
+        For more details, see :meth:`torch.Tensor.reshape`.

-        :param shape: The new shape of the tensor.
-        :type shape: tuple
-        :return: A tensor-like object with updated labels.
+        :param tuple shape: The new shape of the tensor.
+        :return: A new :class:`LabelTensor` instance with the updated shape and
+            labels.
         :rtype: LabelTensor
         """
         # As for now the reshape method is used only in the context of the
         # dataset, the labels are not
         tensor = super().reshape(*shape)