Fix rendering LT

This commit is contained in:
FilippoOlivo
2025-03-14 11:41:23 +01:00
parent feb6ca952a
commit bd65d4b0f7

View File

@@ -33,10 +33,10 @@ class LabelTensor(torch.Tensor):
@property @property
def tensor(self): def tensor(self):
""" """
Give the tensor part of the :class:`~pina.label_tensor.LabelTensor` Returns the tensor part of the :class:`~pina.label_tensor.LabelTensor`
object. object.
:return: tensor part of the :class:`~pina.label_tensor.LabelTensor`. :return: Tensor part of the :class:`~pina.label_tensor.LabelTensor`.
:rtype: torch.Tensor :rtype: torch.Tensor
""" """
@@ -44,10 +44,15 @@ class LabelTensor(torch.Tensor):
def __init__(self, x, labels): def __init__(self, x, labels):
""" """
Construct a :class:`~pina.label_tensor.LabelTensor` by passing a dict of Initialize the :class:`~pina.label_tensor.LabelTensor` instance, by
the labels and a :class:`torch.Tensor`. Internally, the initialization checking the consistency of the labels and the tensor. Specifically, the
method will store check the compatibility of the labels with the tensor labels must match the following conditions:
shape.
- At each dimension, the number of labels must match the size of the \
dimension.
- At each dimension, the labels must be unique.
The labels can be passed in the following formats:
:Example: :Example:
>>> from pina import LabelTensor >>> from pina import LabelTensor
@@ -58,10 +63,17 @@ class LabelTensor(torch.Tensor):
>>> torch.rand((2000, 3)), >>> torch.rand((2000, 3)),
... ["a", "b", "c"]) ... ["a", "b", "c"])
The keys of the dictionary are the dimension indices, and the values are
dictionaries containing the labels and the name of the dimension. If
the labels are passed as a list, these are assigned to the last
dimension.
:param torch.Tensor x: The tensor to be casted as a
:class:`~pina.label_tensor.LabelTensor`.
:param labels: Labels to assign to the tensor.
:type labels: str | list[str] | dict
:raises ValueError: If the labels are not consistent with the tensor.
""" """
# Avoid unused argument warning. x is not used in the constructor
# of the parent class.
# pylint: disable=unused-argument
super().__init__() super().__init__()
if labels is not None: if labels is not None:
self.labels = labels self.labels = labels
@@ -71,7 +83,7 @@ class LabelTensor(torch.Tensor):
@property @property
def full_labels(self): def full_labels(self):
""" """
Gives the full labels of the tensor, even for the dimensions that are Returns the full labels of the tensor, even for the dimensions that are
not labeled. not labeled.
:return: The full labels of the tensor :return: The full labels of the tensor
@@ -89,7 +101,7 @@ class LabelTensor(torch.Tensor):
@property @property
def stored_labels(self): def stored_labels(self):
""" """
Gives the labels stored inside the instance. Returns the labels stored inside the instance.
:return: The labels stored inside the instance. :return: The labels stored inside the instance.
:rtype: dict :rtype: dict
@@ -99,7 +111,7 @@ class LabelTensor(torch.Tensor):
@property @property
def labels(self): def labels(self):
""" """
Give the labels of the last dimension of the instance. Returns the labels of the last dimension of the instance.
:return: labels of last dimension :return: labels of last dimension
:rtype: list :rtype: list
@@ -111,8 +123,9 @@ class LabelTensor(torch.Tensor):
@labels.setter @labels.setter
def labels(self, labels): def labels(self, labels):
""" """
Set the parameter ``_labels`` by checking the type of the input labels Set labels stored insider the instance by checking the type of the
and handling it accordingly. The following types are accepted: input labels and handling it accordingly. The following types are
accepted:
- **list**: The list of labels is assigned to the last dimension. - **list**: The list of labels is assigned to the last dimension.
- **dict**: The dictionary of labels is assigned to the tensor. - **dict**: The dictionary of labels is assigned to the tensor.
@@ -134,7 +147,7 @@ class LabelTensor(torch.Tensor):
else: else:
raise ValueError("labels must be list, dict or string.") raise ValueError("labels must be list, dict or string.")
def _init_labels_from_dict(self, labels: dict): def _init_labels_from_dict(self, labels):
""" """
Store the internal label representation according to the values Store the internal label representation according to the values
passed as input. passed as input.
@@ -146,7 +159,7 @@ class LabelTensor(torch.Tensor):
tensor_shape = self.shape tensor_shape = self.shape
def validate_dof(dof_list, dim_size: int): def validate_dof(dof_list, dim_size):
"""Validate the 'dof' list for uniqueness and size.""" """Validate the 'dof' list for uniqueness and size."""
if len(dof_list) != len(set(dof_list)): if len(dof_list) != len(set(dof_list)):
raise ValueError("dof must be unique") raise ValueError("dof must be unique")
@@ -187,7 +200,7 @@ class LabelTensor(torch.Tensor):
def _init_labels_from_list(self, labels): def _init_labels_from_list(self, labels):
""" """
Given a ``list`` of dof, this method update the internal label Given a list of dof, this method update the internal label
representation by assigning the dof to the last dimension. representation by assigning the dof to the last dimension.
:param labels: The label(s) to update. :param labels: The label(s) to update.
@@ -203,17 +216,25 @@ class LabelTensor(torch.Tensor):
def extract(self, labels_to_extract): def extract(self, labels_to_extract):
""" """
Extract the subset of the original tensor by returning all the positions Extract the subset of the original tensor by returning all the positions
corresponding to the passed ``label_to_extract``. corresponding to the passed ``label_to_extract``. If ``label_to_extract``
is a dictionary, the keys are the dimension names and the values are the
labels to extract. If a single label or a list of labels is passed, the
last dimension is considered.
:param labels_to_extract: The label(s) to extract. If a single label or :Example:
a list of labels is passed, the last dimension is considered. >>> from pina import LabelTensor
If a dictionary is passed, the keys are the dimension names and the >>> labels = {1: {'dof': ["a", "b", "c"], 'name': 'space'}}
values are the labels to extract. >>> tensor = LabelTensor(torch.rand((2000, 3)), labels)
>>> tensor.extract("a")
>>> tensor.extract(["a", "b"])
>>> tensor.extract({"space": ["a", "b"]})
:param labels_to_extract: The label(s) to extract.
:type labels_to_extract: str | list[str] | tuple[str] | dict :type labels_to_extract: str | list[str] | tuple[str] | dict
:return: The extracted tensor with the updated labels. :return: The extracted tensor with the updated labels.
:rtype: LabelTensor :rtype: LabelTensor
:raises TypeError: Labels are not ``str``, ``list of str`` or ``dict`` :raises TypeError: Labels are not ``str``, ``list[str]`` or ``dict``
properly setted. properly setted.
:raises ValueError: Label to extract is not in the labels ``list``. :raises ValueError: Label to extract is not in the labels ``list``.
""" """
@@ -298,13 +319,13 @@ class LabelTensor(torch.Tensor):
:param list[LabelTensor] tensors: :param list[LabelTensor] tensors:
:class:`~pina.label_tensor.LabelTensor` instances to concatenate :class:`~pina.label_tensor.LabelTensor` instances to concatenate
:param int dim: dimensions on which you want to perform the operation :param int dim: Dimensions on which you want to perform the operation
(default is 0) (default is 0)
:return: A new :class:`LabelTensor' instance obtained by concatenating :return: A new :class:`LabelTensor` instance obtained by concatenating
the input instances, with the updated labels. the input instances.
:rtype: LabelTensor :rtype: LabelTensor
:raises ValueError: either number dof or dimensions names differ :raises ValueError: either number dof or dimensions names differ.
""" """
if not tensors: if not tensors:
@@ -353,7 +374,7 @@ class LabelTensor(torch.Tensor):
:param list[LabelTensor] tensors: A list of tensors to stack. :param list[LabelTensor] tensors: A list of tensors to stack.
All tensors must have the same shape. All tensors must have the same shape.
:return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained :return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained
by stacking the input tensors, with the updated labels. by stacking the input tensors.
:rtype: LabelTensor :rtype: LabelTensor
""" """
@@ -389,7 +410,7 @@ class LabelTensor(torch.Tensor):
Give the ``dtype`` of the tensor. For more details, see Give the ``dtype`` of the tensor. For more details, see
:meth:`torch.dtype`. :meth:`torch.dtype`.
:return: dtype of the tensor :return: The data type of the tensor.
:rtype: torch.dtype :rtype: torch.dtype
""" """
@@ -427,19 +448,19 @@ class LabelTensor(torch.Tensor):
def append(self, tensor, mode="std"): def append(self, tensor, mode="std"):
""" """
Appends a given tensor to the current tensor along the last dimension. Appends a given tensor to the current tensor along the last dimension.
This method supports two types of appending operations: This method supports two types of appending operations:
1. **Standard append** ("std"): Concatenates the input tensor with the
1. **Standard append** ("std"): Concatenates the input tensor with the \
current tensor along the last dimension. current tensor along the last dimension.
2. **Cross append** ("cross"): Creates a cross-product of the current 2. **Cross append** ("cross"): Creates a cross-product of the current \
tensor and the input tensor by repeating them in a cross-product tensor and the input tensor.
fashion, then concatenates the result along the last dimension.
:param tensor: The tensor to append to the current tensor. :param tensor: The tensor to append to the current tensor.
:type tensor: LabelTensor :type tensor: LabelTensor
:param mode: The append mode to use. Defaults to "std". :param mode: The append mode to use. Defaults to ``st``.
:type mode: str, optional :type mode: str, optional
:return: A new `LabelTensor` obtained by appending the input tensor. :return: A new :class:`LabelTensor` instance obtained by appending the
input tensor.
:rtype: LabelTensor :rtype: LabelTensor
:raises ValueError: If the mode is not "std" or "cross". :raises ValueError: If the mode is not "std" or "cross".
@@ -468,7 +489,7 @@ class LabelTensor(torch.Tensor):
raise ValueError('mode must be either "std" or "cross"') raise ValueError('mode must be either "std" or "cross"')
@staticmethod @staticmethod
def vstack(label_tensors): def vstack(tensors):
""" """
Stack tensors vertically. For more details, see :meth:`torch.vstack`. Stack tensors vertically. For more details, see :meth:`torch.vstack`.
@@ -480,7 +501,7 @@ class LabelTensor(torch.Tensor):
:rtype: LabelTensor :rtype: LabelTensor
""" """
return LabelTensor.cat(label_tensors, dim=0) return LabelTensor.cat(tensors, dim=0)
# This method is used to update labels # This method is used to update labels
def _update_single_label( def _update_single_label(
@@ -592,11 +613,11 @@ class LabelTensor(torch.Tensor):
def sort_labels(self, dim=None): def sort_labels(self, dim=None):
""" """
Sort the labels along the specified dimension and apply the same sorting Sort the labels along the specified dimension and apply. It applies the
to the :class:`torch.Tensor` part of the instance. same sorting to the tensor part of the instance.
:param int dim: The dimension along which to sort the labels. :param int dim: The dimension along which to sort the labels.
If ``None``, the last dimension (``ndim - 1``) is used. If ``None``, the last dimension is used.
:return: A new tensor with sorted labels along the specified dimension. :return: A new tensor with sorted labels along the specified dimension.
:rtype: LabelTensor :rtype: LabelTensor
""" """