Doc conditions

This commit is contained in:
FilippoOlivo
2025-03-11 15:22:35 +01:00
parent 92bb04fafe
commit be0e39a050
7 changed files with 119 additions and 54 deletions

View File

@@ -23,17 +23,20 @@ class DataCondition(ConditionInterface):
def __new__(cls, input, conditional_variables=None):
"""
Instanciate the correct subclass of DataCondition by checking the type
of the input data (input and conditional_variables).
Instantiate the appropriate subclass of DataCondition based on the
types of input data.
:param input: Input data for the condition.
:type input: torch.Tensor | LabelTensor | Graph | Data
:param conditional_variables: Conditional variables for the condition.
:type conditional_variables: torch.Tensor | LabelTensor
:return: Subclass of DataCondition.
:rtype: TensorDataCondition | GraphDataCondition
:raises ValueError: If input is not of type :class:`torch.Tensor`,
:class:`LabelTensor`, :class:`Graph`, or :class:`Data`.
:param input: torch.Tensor or Graph/Data object containing the input
data
:type input: torch.Tensor or Graph or Data
:param conditional_variables: torch.Tensor or LabelTensor containing
the conditional variables
:type conditional_variables: torch.Tensor or LabelTensor
:return: DataCondition subclass
:rtype: TensorDataCondition or GraphDataCondition
"""
if cls != DataCondition:
return super().__new__(cls)
@@ -56,12 +59,15 @@ class DataCondition(ConditionInterface):
Initialize the DataCondition, storing the input and conditional
variables (if any).
:param input: torch.Tensor or Graph/Data object containing the input
data
:param input: Input data for the condition.
:type input: torch.Tensor or Graph or Data
:param conditional_variables: torch.Tensor or LabelTensor containing
the conditional variables
:param conditional_variables: Conditional variables for the condition.
:type conditional_variables: torch.Tensor or LabelTensor
.. note::
If either `input` is composed by a list of :class:`Graph`/
:class:`Data` objects, all elements must have the same structure
(keys and data types)
"""
super().__init__()
self.input = input