Additional fix in condition

This commit is contained in:
FilippoOlivo
2025-03-12 23:14:02 +01:00
parent 972fa2101d
commit 8a3c169ffd
5 changed files with 34 additions and 40 deletions

View File

@@ -11,9 +11,8 @@ from ..graph import Graph
class DataCondition(ConditionInterface):
"""
This condition must be used every time a Unsupervised Loss is needed in
the Solver. The `conditional_variable` can be passed as extra-input when
the model learns a conditional distribution.
Condition defined by input data and conditional variables. It can be used
in unsupervised learning problems.
"""
__slots__ = ["input", "conditional_variables"]
@@ -22,14 +21,14 @@ class DataCondition(ConditionInterface):
def __new__(cls, input, conditional_variables=None):
"""
Instantiate the appropriate subclass of DataCondition based on the
types of input data.
Instantiate the appropriate subclass of :class:`DataCondition` based on
the type of `input`.
:param input: Input data for the condition.
:type input: torch.Tensor | LabelTensor | Graph |
Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data]
:param conditional_variables: Conditional variables for the condition.
:type conditional_variables: torch.Tensor | LabelTensor
:type conditional_variables: torch.Tensor | LabelTensor, optional
:return: Subclass of DataCondition.
:rtype: pina.condition.data_condition.TensorDataCondition |
pina.condition.data_condition.GraphDataCondition
@@ -37,9 +36,8 @@ class DataCondition(ConditionInterface):
:raises ValueError: If input is not of type :class:`torch.Tensor`,
:class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`,
or :class:`~torch_geometric.data.Data`.
"""
if cls != DataCondition:
return super().__new__(cls)
if isinstance(input, (torch.Tensor, LabelTensor)):
@@ -69,8 +67,8 @@ class DataCondition(ConditionInterface):
.. note::
If either `input` is composed by a list of
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`
objects, all elements must have the same structure (keys and data
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`,
all elements must have the same structure (keys and data
types)
"""