Fix rendering LT
This commit is contained in:
@@ -33,10 +33,10 @@ class LabelTensor(torch.Tensor):
|
|||||||
@property
|
@property
|
||||||
def tensor(self):
|
def tensor(self):
|
||||||
"""
|
"""
|
||||||
Give the tensor part of the :class:`~pina.label_tensor.LabelTensor`
|
Returns the tensor part of the :class:`~pina.label_tensor.LabelTensor`
|
||||||
object.
|
object.
|
||||||
|
|
||||||
:return: tensor part of the :class:`~pina.label_tensor.LabelTensor`.
|
:return: Tensor part of the :class:`~pina.label_tensor.LabelTensor`.
|
||||||
:rtype: torch.Tensor
|
:rtype: torch.Tensor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -44,10 +44,15 @@ class LabelTensor(torch.Tensor):
|
|||||||
|
|
||||||
def __init__(self, x, labels):
|
def __init__(self, x, labels):
|
||||||
"""
|
"""
|
||||||
Construct a :class:`~pina.label_tensor.LabelTensor` by passing a dict of
|
Initialize the :class:`~pina.label_tensor.LabelTensor` instance, by
|
||||||
the labels and a :class:`torch.Tensor`. Internally, the initialization
|
checking the consistency of the labels and the tensor. Specifically, the
|
||||||
method will store check the compatibility of the labels with the tensor
|
labels must match the following conditions:
|
||||||
shape.
|
|
||||||
|
- At each dimension, the number of labels must match the size of the \
|
||||||
|
dimension.
|
||||||
|
- At each dimension, the labels must be unique.
|
||||||
|
|
||||||
|
The labels can be passed in the following formats:
|
||||||
|
|
||||||
:Example:
|
:Example:
|
||||||
>>> from pina import LabelTensor
|
>>> from pina import LabelTensor
|
||||||
@@ -57,11 +62,18 @@ class LabelTensor(torch.Tensor):
|
|||||||
>>> tensor = LabelTensor(
|
>>> tensor = LabelTensor(
|
||||||
>>> torch.rand((2000, 3)),
|
>>> torch.rand((2000, 3)),
|
||||||
... ["a", "b", "c"])
|
... ["a", "b", "c"])
|
||||||
|
|
||||||
|
The keys of the dictionary are the dimension indices, and the values are
|
||||||
|
dictionaries containing the labels and the name of the dimension. If
|
||||||
|
the labels are passed as a list, these are assigned to the last
|
||||||
|
dimension.
|
||||||
|
|
||||||
|
:param torch.Tensor x: The tensor to be casted as a
|
||||||
|
:class:`~pina.label_tensor.LabelTensor`.
|
||||||
|
:param labels: Labels to assign to the tensor.
|
||||||
|
:type labels: str | list[str] | dict
|
||||||
|
:raises ValueError: If the labels are not consistent with the tensor.
|
||||||
"""
|
"""
|
||||||
# Avoid unused argument warning. x is not used in the constructor
|
|
||||||
# of the parent class.
|
|
||||||
# pylint: disable=unused-argument
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
self.labels = labels
|
self.labels = labels
|
||||||
@@ -71,7 +83,7 @@ class LabelTensor(torch.Tensor):
|
|||||||
@property
|
@property
|
||||||
def full_labels(self):
|
def full_labels(self):
|
||||||
"""
|
"""
|
||||||
Gives the full labels of the tensor, even for the dimensions that are
|
Returns the full labels of the tensor, even for the dimensions that are
|
||||||
not labeled.
|
not labeled.
|
||||||
|
|
||||||
:return: The full labels of the tensor
|
:return: The full labels of the tensor
|
||||||
@@ -89,7 +101,7 @@ class LabelTensor(torch.Tensor):
|
|||||||
@property
|
@property
|
||||||
def stored_labels(self):
|
def stored_labels(self):
|
||||||
"""
|
"""
|
||||||
Gives the labels stored inside the instance.
|
Returns the labels stored inside the instance.
|
||||||
|
|
||||||
:return: The labels stored inside the instance.
|
:return: The labels stored inside the instance.
|
||||||
:rtype: dict
|
:rtype: dict
|
||||||
@@ -99,7 +111,7 @@ class LabelTensor(torch.Tensor):
|
|||||||
@property
|
@property
|
||||||
def labels(self):
|
def labels(self):
|
||||||
"""
|
"""
|
||||||
Give the labels of the last dimension of the instance.
|
Returns the labels of the last dimension of the instance.
|
||||||
|
|
||||||
:return: labels of last dimension
|
:return: labels of last dimension
|
||||||
:rtype: list
|
:rtype: list
|
||||||
@@ -111,8 +123,9 @@ class LabelTensor(torch.Tensor):
|
|||||||
@labels.setter
|
@labels.setter
|
||||||
def labels(self, labels):
|
def labels(self, labels):
|
||||||
"""
|
"""
|
||||||
Set the parameter ``_labels`` by checking the type of the input labels
|
Set labels stored insider the instance by checking the type of the
|
||||||
and handling it accordingly. The following types are accepted:
|
input labels and handling it accordingly. The following types are
|
||||||
|
accepted:
|
||||||
|
|
||||||
- **list**: The list of labels is assigned to the last dimension.
|
- **list**: The list of labels is assigned to the last dimension.
|
||||||
- **dict**: The dictionary of labels is assigned to the tensor.
|
- **dict**: The dictionary of labels is assigned to the tensor.
|
||||||
@@ -134,7 +147,7 @@ class LabelTensor(torch.Tensor):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("labels must be list, dict or string.")
|
raise ValueError("labels must be list, dict or string.")
|
||||||
|
|
||||||
def _init_labels_from_dict(self, labels: dict):
|
def _init_labels_from_dict(self, labels):
|
||||||
"""
|
"""
|
||||||
Store the internal label representation according to the values
|
Store the internal label representation according to the values
|
||||||
passed as input.
|
passed as input.
|
||||||
@@ -146,7 +159,7 @@ class LabelTensor(torch.Tensor):
|
|||||||
|
|
||||||
tensor_shape = self.shape
|
tensor_shape = self.shape
|
||||||
|
|
||||||
def validate_dof(dof_list, dim_size: int):
|
def validate_dof(dof_list, dim_size):
|
||||||
"""Validate the 'dof' list for uniqueness and size."""
|
"""Validate the 'dof' list for uniqueness and size."""
|
||||||
if len(dof_list) != len(set(dof_list)):
|
if len(dof_list) != len(set(dof_list)):
|
||||||
raise ValueError("dof must be unique")
|
raise ValueError("dof must be unique")
|
||||||
@@ -187,7 +200,7 @@ class LabelTensor(torch.Tensor):
|
|||||||
|
|
||||||
def _init_labels_from_list(self, labels):
|
def _init_labels_from_list(self, labels):
|
||||||
"""
|
"""
|
||||||
Given a ``list`` of dof, this method update the internal label
|
Given a list of dof, this method update the internal label
|
||||||
representation by assigning the dof to the last dimension.
|
representation by assigning the dof to the last dimension.
|
||||||
|
|
||||||
:param labels: The label(s) to update.
|
:param labels: The label(s) to update.
|
||||||
@@ -203,17 +216,25 @@ class LabelTensor(torch.Tensor):
|
|||||||
def extract(self, labels_to_extract):
|
def extract(self, labels_to_extract):
|
||||||
"""
|
"""
|
||||||
Extract the subset of the original tensor by returning all the positions
|
Extract the subset of the original tensor by returning all the positions
|
||||||
corresponding to the passed ``label_to_extract``.
|
corresponding to the passed ``label_to_extract``. If ``label_to_extract``
|
||||||
|
is a dictionary, the keys are the dimension names and the values are the
|
||||||
|
labels to extract. If a single label or a list of labels is passed, the
|
||||||
|
last dimension is considered.
|
||||||
|
|
||||||
:param labels_to_extract: The label(s) to extract. If a single label or
|
:Example:
|
||||||
a list of labels is passed, the last dimension is considered.
|
>>> from pina import LabelTensor
|
||||||
If a dictionary is passed, the keys are the dimension names and the
|
>>> labels = {1: {'dof': ["a", "b", "c"], 'name': 'space'}}
|
||||||
values are the labels to extract.
|
>>> tensor = LabelTensor(torch.rand((2000, 3)), labels)
|
||||||
|
>>> tensor.extract("a")
|
||||||
|
>>> tensor.extract(["a", "b"])
|
||||||
|
>>> tensor.extract({"space": ["a", "b"]})
|
||||||
|
|
||||||
|
:param labels_to_extract: The label(s) to extract.
|
||||||
:type labels_to_extract: str | list[str] | tuple[str] | dict
|
:type labels_to_extract: str | list[str] | tuple[str] | dict
|
||||||
:return: The extracted tensor with the updated labels.
|
:return: The extracted tensor with the updated labels.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
|
|
||||||
:raises TypeError: Labels are not ``str``, ``list of str`` or ``dict``
|
:raises TypeError: Labels are not ``str``, ``list[str]`` or ``dict``
|
||||||
properly setted.
|
properly setted.
|
||||||
:raises ValueError: Label to extract is not in the labels ``list``.
|
:raises ValueError: Label to extract is not in the labels ``list``.
|
||||||
"""
|
"""
|
||||||
@@ -298,13 +319,13 @@ class LabelTensor(torch.Tensor):
|
|||||||
|
|
||||||
:param list[LabelTensor] tensors:
|
:param list[LabelTensor] tensors:
|
||||||
:class:`~pina.label_tensor.LabelTensor` instances to concatenate
|
:class:`~pina.label_tensor.LabelTensor` instances to concatenate
|
||||||
:param int dim: dimensions on which you want to perform the operation
|
:param int dim: Dimensions on which you want to perform the operation
|
||||||
(default is 0)
|
(default is 0)
|
||||||
:return: A new :class:`LabelTensor' instance obtained by concatenating
|
:return: A new :class:`LabelTensor` instance obtained by concatenating
|
||||||
the input instances, with the updated labels.
|
the input instances.
|
||||||
|
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
:raises ValueError: either number dof or dimensions names differ
|
:raises ValueError: either number dof or dimensions names differ.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not tensors:
|
if not tensors:
|
||||||
@@ -353,7 +374,7 @@ class LabelTensor(torch.Tensor):
|
|||||||
:param list[LabelTensor] tensors: A list of tensors to stack.
|
:param list[LabelTensor] tensors: A list of tensors to stack.
|
||||||
All tensors must have the same shape.
|
All tensors must have the same shape.
|
||||||
:return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained
|
:return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained
|
||||||
by stacking the input tensors, with the updated labels.
|
by stacking the input tensors.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -389,7 +410,7 @@ class LabelTensor(torch.Tensor):
|
|||||||
Give the ``dtype`` of the tensor. For more details, see
|
Give the ``dtype`` of the tensor. For more details, see
|
||||||
:meth:`torch.dtype`.
|
:meth:`torch.dtype`.
|
||||||
|
|
||||||
:return: dtype of the tensor
|
:return: The data type of the tensor.
|
||||||
:rtype: torch.dtype
|
:rtype: torch.dtype
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -427,19 +448,19 @@ class LabelTensor(torch.Tensor):
|
|||||||
def append(self, tensor, mode="std"):
|
def append(self, tensor, mode="std"):
|
||||||
"""
|
"""
|
||||||
Appends a given tensor to the current tensor along the last dimension.
|
Appends a given tensor to the current tensor along the last dimension.
|
||||||
|
|
||||||
This method supports two types of appending operations:
|
This method supports two types of appending operations:
|
||||||
1. **Standard append** ("std"): Concatenates the input tensor with the
|
|
||||||
|
1. **Standard append** ("std"): Concatenates the input tensor with the \
|
||||||
current tensor along the last dimension.
|
current tensor along the last dimension.
|
||||||
2. **Cross append** ("cross"): Creates a cross-product of the current
|
2. **Cross append** ("cross"): Creates a cross-product of the current \
|
||||||
tensor and the input tensor by repeating them in a cross-product
|
tensor and the input tensor.
|
||||||
fashion, then concatenates the result along the last dimension.
|
|
||||||
|
|
||||||
:param tensor: The tensor to append to the current tensor.
|
:param tensor: The tensor to append to the current tensor.
|
||||||
:type tensor: LabelTensor
|
:type tensor: LabelTensor
|
||||||
:param mode: The append mode to use. Defaults to "std".
|
:param mode: The append mode to use. Defaults to ``st``.
|
||||||
:type mode: str, optional
|
:type mode: str, optional
|
||||||
:return: A new `LabelTensor` obtained by appending the input tensor.
|
:return: A new :class:`LabelTensor` instance obtained by appending the
|
||||||
|
input tensor.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
|
|
||||||
:raises ValueError: If the mode is not "std" or "cross".
|
:raises ValueError: If the mode is not "std" or "cross".
|
||||||
@@ -468,7 +489,7 @@ class LabelTensor(torch.Tensor):
|
|||||||
raise ValueError('mode must be either "std" or "cross"')
|
raise ValueError('mode must be either "std" or "cross"')
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def vstack(label_tensors):
|
def vstack(tensors):
|
||||||
"""
|
"""
|
||||||
Stack tensors vertically. For more details, see :meth:`torch.vstack`.
|
Stack tensors vertically. For more details, see :meth:`torch.vstack`.
|
||||||
|
|
||||||
@@ -480,7 +501,7 @@ class LabelTensor(torch.Tensor):
|
|||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return LabelTensor.cat(label_tensors, dim=0)
|
return LabelTensor.cat(tensors, dim=0)
|
||||||
|
|
||||||
# This method is used to update labels
|
# This method is used to update labels
|
||||||
def _update_single_label(
|
def _update_single_label(
|
||||||
@@ -592,11 +613,11 @@ class LabelTensor(torch.Tensor):
|
|||||||
|
|
||||||
def sort_labels(self, dim=None):
|
def sort_labels(self, dim=None):
|
||||||
"""
|
"""
|
||||||
Sort the labels along the specified dimension and apply the same sorting
|
Sort the labels along the specified dimension and apply. It applies the
|
||||||
to the :class:`torch.Tensor` part of the instance.
|
same sorting to the tensor part of the instance.
|
||||||
|
|
||||||
:param int dim: The dimension along which to sort the labels.
|
:param int dim: The dimension along which to sort the labels.
|
||||||
If ``None``, the last dimension (``ndim - 1``) is used.
|
If ``None``, the last dimension is used.
|
||||||
:return: A new tensor with sorted labels along the specified dimension.
|
:return: A new tensor with sorted labels along the specified dimension.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user