Small fixes in conditions

This commit is contained in:
FilippoOlivo
2025-03-11 16:26:28 +01:00
committed by Nicola Demo
parent 63028f2b01
commit 56e45c6c87
5 changed files with 67 additions and 60 deletions

View File

@@ -25,19 +25,20 @@ class InputTargetCondition(ConditionInterface):
the types of input and target data.
:param input: Input data for the condition.
:type input: torch.Tensor | Graph | Data | list | tuple
:type input: torch.Tensor | Graph | torch_geometric.data.Data | list | \
tuple
:param target: Target data for the condition.
Graph, Data, or list/tuple.
:type target: torch.Tensor | Graph | Data | list | tuple
:type target: torch.Tensor | Graph | torch_geometric.data.Data | list \
| tuple
:return: Subclass of InputTargetCondition
:rtype: TensorInputTensorTargetCondition |
TensorInputGraphTargetCondition |
GraphInputTensorTargetCondition |
:rtype: TensorInputTensorTargetCondition | \
TensorInputGraphTargetCondition | \
GraphInputTensorTargetCondition | \
GraphInputGraphTargetCondition
:raises ValueError: If input and or target are not of type
:class:`torch.Tensor`, :class:`LabelTensor`, :class:`Graph`, or
:class:`Data`.
:class:`torch_geometric.data.Data`.
"""
if cls != InputTargetCondition:
return super().__new__(cls)
@@ -71,23 +72,23 @@ class InputTargetCondition(ConditionInterface):
raise ValueError(
"Invalid input/target types. "
"Please provide either Data, Graph, LabelTensor or torch.Tensor "
"objects."
"Please provide either torch_geometric.data.Data, Graph, LabelTensor "
"or torch.Tensor objects."
)
def __init__(self, input, target):
"""
Initialize the InputTargetCondition, storing the input and target data.
:param input: torch.Tensor or Graph/Data object containing the input
:type input: torch.Tensor | Graph or Data
:param target: torch.Tensor or Graph/Data object containing the target
:type target: torch.Tensor or Graph or Data
:param input: Input data for the condition.
:type input: torch.Tensor | Graph | torch_geometric.data.Data
:param target: Target data for the condition.
:type target: torch.Tensor | Graph | torch_geometric.data.Data
.. note::
If either ``input`` or ``target`` are composed by a list of
:class:`Graph`/:class:`Data` objects, all elements must have the
same structure (keys and data types)
:class:`Graph`/:class:`torch_geometric.data.Data` objects, all
elements must have the same structure (keys and data types)
"""
super().__init__()
@@ -117,19 +118,20 @@ class TensorInputTensorTargetCondition(InputTargetCondition):
class TensorInputGraphTargetCondition(InputTargetCondition):
"""
InputTargetCondition subclass for :class:`torch.Tensor`/:class:`LabelTensor`
input and :class:`Graph`/:class:`Data` target data.
input and :class:`Graph`/:class:`torch_geometric.data.Data` target data.
"""
class GraphInputTensorTargetCondition(InputTargetCondition):
"""
InputTargetCondition subclass for :class:`Graph`/:class:`Data` input and
:class:`torch.Tensor`/:class:`LabelTensor` target data.
InputTargetCondition subclass for :class:`Graph`/
:class:`torch_geometric.data.Data` input and :class:`torch.Tensor`/
:class:`LabelTensor` target data.
"""
class GraphInputGraphTargetCondition(InputTargetCondition):
"""
InputTargetCondition subclass for :class:`Graph`/:class:`Data` input and
target data.
InputTargetCondition subclass for :class:`Graph`/
:class:`torch_geometric.data.Data` input and target data.
"""