Doc conditions

This commit is contained in:
FilippoOlivo
2025-03-11 15:22:35 +01:00
parent 92bb04fafe
commit be0e39a050
7 changed files with 119 additions and 54 deletions

View File

@@ -22,15 +22,18 @@ class InputEquationCondition(ConditionInterface):
def __new__(cls, input, equation):
"""
Instanciate the correct subclass of InputEquationCondition by checking
the type of the input data (only `input`).
Instantiate the appropriate subclass of InputEquationCondition based on
the type of input data.
:param input: torch.Tensor or Graph/Data object containing the input
:type input: torch.Tensor or Graph or Data
:param input: Input data. It can be a LabelTensor or a Graph object.
:type input: LabelTensor | Graph
:param EquationInterface equation: Equation object containing the
equation function
:return: InputEquationCondition subclass
:rtype: InputTensorEquationCondition or InputGraphEquationCondition
equation function.
:return: Subclass of InputEquationCondition, based on the input type.
:rtype: InputTensorEquationCondition | InputGraphEquationCondition
:raises ValueError: If input is not of type :class:`torch.Tensor`,
:class:`LabelTensor`, :class:`Graph`, or :class:`Data`.
"""
# If the class is already a subclass, return the instance
@@ -56,11 +59,18 @@ class InputEquationCondition(ConditionInterface):
"""
Initialize the InputEquationCondition by storing the input and equation.
:param input: torch.Tensor or Graph/Data object containing the input
:type input: torch.Tensor or Graph or Data
:param input: torch.Tensor or Graph/Data object containing the input.
:type input: torch.Tensor | Graph
:param EquationInterface equation: Equation object containing the
equation function
equation function.
.. note::
If ``input`` is composed by a list of :class:`Graph`/:class:`Data`
objects, all elements must have the same structure (keys and data
types). Moreover, at least one attribute must be a
:class:`LabelTensor`.
"""
super().__init__()
self.input = input
self.equation = equation
@@ -90,11 +100,15 @@ class InputGraphEquationCondition(InputEquationCondition):
@staticmethod
def _check_label_tensor(input):
"""
Check if the input is a LabelTensor.
Check if at least one LabelTensor is present in the Graph object.
:param input: input data
:param input: Input data.
:type input: torch.Tensor or Graph or Data
:raises ValueError: If the input data object does not contain at least
one LabelTensor.
"""
# Store the fist element of the list/tuple if input is a list/tuple
# it is anougth to check the first element because all elements must
# have the same type and structure (already checked)