Small fixes in conditions

This commit is contained in:
FilippoOlivo
2025-03-11 16:26:28 +01:00
parent be0e39a050
commit 0a1963b204
5 changed files with 67 additions and 60 deletions

View File

@@ -11,10 +11,9 @@ from ..graph import Graph
class DataCondition(ConditionInterface):
"""
Condition for data. This condition must be used every
time a Unsupervised Loss is needed in the Solver. The conditionalvariable
can be passed as extra-input when the model learns a conditional
distribution
This condition must be used every time a Unsupervised Loss is needed in
the Solver. The conditionalvariable can be passed as extra-input when
the model learns a conditional distribution.
"""
__slots__ = ["input", "conditional_variables"]
@@ -27,14 +26,16 @@ class DataCondition(ConditionInterface):
types of input data.
:param input: Input data for the condition.
:type input: torch.Tensor | LabelTensor | Graph | Data
:type input: torch.Tensor | LabelTensor | Graph | \
torch_geometric.data.Data
:param conditional_variables: Conditional variables for the condition.
:type conditional_variables: torch.Tensor | LabelTensor
:return: Subclass of DataCondition.
:rtype: TensorDataCondition | GraphDataCondition
:raises ValueError: If input is not of type :class:`torch.Tensor`,
:class:`LabelTensor`, :class:`Graph`, or :class:`Data`.
:class:`LabelTensor`, :class:`Graph`, or
:class:`torch_geometric.data.Data`.
"""
@@ -51,7 +52,7 @@ class DataCondition(ConditionInterface):
raise ValueError(
"Invalid input types. "
"Please provide either Data or Graph objects."
"Please provide either torch_geometric.data.Data or Graph objects."
)
def __init__(self, input, conditional_variables=None):
@@ -60,14 +61,14 @@ class DataCondition(ConditionInterface):
variables (if any).
:param input: Input data for the condition.
:type input: torch.Tensor or Graph or Data
:type input: torch.Tensor or Graph or torch_geometric.data.Data
:param conditional_variables: Conditional variables for the condition.
:type conditional_variables: torch.Tensor or LabelTensor
.. note::
If either `input` is composed by a list of :class:`Graph`/
:class:`Data` objects, all elements must have the same structure
(keys and data types)
:class:`torch_geometric.data.Data` objects, all elements must have
the same structure (keys and data types)
"""
super().__init__()
self.input = input