Doc conditions

This commit is contained in:
FilippoOlivo
2025-03-11 15:22:35 +01:00
committed by Nicola Demo
parent 0e5275f4c0
commit 63028f2b01
7 changed files with 119 additions and 54 deletions

View File

@@ -21,17 +21,23 @@ class InputTargetCondition(ConditionInterface):
def __new__(cls, input, target):
"""
Instanciate the correct subclass of InputTargetCondition by checking the
type of the input and target data.
Instantiate the appropriate subclass of InputTargetCondition based on
the types of input and target data.
:param input: torch.Tensor or Graph/Data object containing the input
:type input: torch.Tensor or Graph or Data
:param target: torch.Tensor or Graph/Data object containing the target
:type target: torch.Tensor or Graph or Data
:return: InputTargetCondition subclass
:rtype: TensorInputTensorTargetCondition or
TensorInputGraphTargetCondition or GraphInputTensorTargetCondition
or GraphInputGraphTargetCondition
:param input: Input data for the condition.
:type input: torch.Tensor | Graph | Data | list | tuple
:param target: Target data for the condition.
Graph, Data, or list/tuple.
:type target: torch.Tensor | Graph | Data | list | tuple
:return: Subclass of InputTargetCondition
:rtype: TensorInputTensorTargetCondition |
TensorInputGraphTargetCondition |
GraphInputTensorTargetCondition |
GraphInputGraphTargetCondition
:raises ValueError: If input and or target are not of type
:class:`torch.Tensor`, :class:`LabelTensor`, :class:`Graph`, or
:class:`Data`.
"""
if cls != InputTargetCondition:
return super().__new__(cls)
@@ -74,10 +80,16 @@ class InputTargetCondition(ConditionInterface):
Initialize the InputTargetCondition, storing the input and target data.
:param input: torch.Tensor or Graph/Data object containing the input
:type input: torch.Tensor or Graph or Data
:type input: torch.Tensor | Graph or Data
:param target: torch.Tensor or Graph/Data object containing the target
:type target: torch.Tensor or Graph or Data
.. note::
If either ``input`` or ``target`` are composed by a list of
:class:`Graph`/:class:`Data` objects, all elements must have the
same structure (keys and data types)
"""
super().__init__()
self._check_input_target_len(input, target)
self.input = input
@@ -97,25 +109,27 @@ class InputTargetCondition(ConditionInterface):
class TensorInputTensorTargetCondition(InputTargetCondition):
"""
InputTargetCondition subclass for torch.Tensor input and target data.
InputTargetCondition subclass for :class:`torch.Tensor`/:class:`LabelTensor`
input and target data.
"""
class TensorInputGraphTargetCondition(InputTargetCondition):
"""
InputTargetCondition subclass for torch.Tensor input and Graph/Data target
data.
InputTargetCondition subclass for :class:`torch.Tensor`/:class:`LabelTensor`
input and :class:`Graph`/:class:`Data` target data.
"""
class GraphInputTensorTargetCondition(InputTargetCondition):
"""
InputTargetCondition subclass for Graph/Data input and torch.Tensor target
data.
InputTargetCondition subclass for :class:`Graph`/:class:`Data` input and
:class:`torch.Tensor`/:class:`LabelTensor` target data.
"""
class GraphInputGraphTargetCondition(InputTargetCondition):
"""
InputTargetCondition subclass for Graph/Data input and target data.
InputTargetCondition subclass for :class:`Graph`/:class:`Data` input and
target data.
"""