Small fixes in conditions
This commit is contained in:
@@ -25,19 +25,20 @@ class InputTargetCondition(ConditionInterface):
|
||||
the types of input and target data.
|
||||
|
||||
:param input: Input data for the condition.
|
||||
:type input: torch.Tensor | Graph | Data | list | tuple
|
||||
:type input: torch.Tensor | Graph | torch_geometric.data.Data | list | \
|
||||
tuple
|
||||
:param target: Target data for the condition.
|
||||
Graph, Data, or list/tuple.
|
||||
:type target: torch.Tensor | Graph | Data | list | tuple
|
||||
:type target: torch.Tensor | Graph | torch_geometric.data.Data | list \
|
||||
| tuple
|
||||
:return: Subclass of InputTargetCondition
|
||||
:rtype: TensorInputTensorTargetCondition |
|
||||
TensorInputGraphTargetCondition |
|
||||
GraphInputTensorTargetCondition |
|
||||
:rtype: TensorInputTensorTargetCondition | \
|
||||
TensorInputGraphTargetCondition | \
|
||||
GraphInputTensorTargetCondition | \
|
||||
GraphInputGraphTargetCondition
|
||||
|
||||
:raises ValueError: If input and or target are not of type
|
||||
:class:`torch.Tensor`, :class:`LabelTensor`, :class:`Graph`, or
|
||||
:class:`Data`.
|
||||
:class:`torch_geometric.data.Data`.
|
||||
"""
|
||||
if cls != InputTargetCondition:
|
||||
return super().__new__(cls)
|
||||
@@ -71,23 +72,23 @@ class InputTargetCondition(ConditionInterface):
|
||||
|
||||
raise ValueError(
|
||||
"Invalid input/target types. "
|
||||
"Please provide either Data, Graph, LabelTensor or torch.Tensor "
|
||||
"objects."
|
||||
"Please provide either torch_geometric.data.Data, Graph, LabelTensor "
|
||||
"or torch.Tensor objects."
|
||||
)
|
||||
|
||||
def __init__(self, input, target):
|
||||
"""
|
||||
Initialize the InputTargetCondition, storing the input and target data.
|
||||
|
||||
:param input: torch.Tensor or Graph/Data object containing the input
|
||||
:type input: torch.Tensor | Graph or Data
|
||||
:param target: torch.Tensor or Graph/Data object containing the target
|
||||
:type target: torch.Tensor or Graph or Data
|
||||
:param input: Input data for the condition.
|
||||
:type input: torch.Tensor | Graph | torch_geometric.data.Data
|
||||
:param target: Target data for the condition.
|
||||
:type target: torch.Tensor | Graph | torch_geometric.data.Data
|
||||
|
||||
.. note::
|
||||
If either ``input`` or ``target`` are composed by a list of
|
||||
:class:`Graph`/:class:`Data` objects, all elements must have the
|
||||
same structure (keys and data types)
|
||||
:class:`Graph`/:class:`torch_geometric.data.Data` objects, all
|
||||
elements must have the same structure (keys and data types)
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
@@ -117,19 +118,20 @@ class TensorInputTensorTargetCondition(InputTargetCondition):
|
||||
class TensorInputGraphTargetCondition(InputTargetCondition):
|
||||
"""
|
||||
InputTargetCondition subclass for :class:`torch.Tensor`/:class:`LabelTensor`
|
||||
input and :class:`Graph`/:class:`Data` target data.
|
||||
input and :class:`Graph`/:class:`torch_geometric.data.Data` target data.
|
||||
"""
|
||||
|
||||
|
||||
class GraphInputTensorTargetCondition(InputTargetCondition):
|
||||
"""
|
||||
InputTargetCondition subclass for :class:`Graph`/:class:`Data` input and
|
||||
:class:`torch.Tensor`/:class:`LabelTensor` target data.
|
||||
InputTargetCondition subclass for :class:`Graph`/
|
||||
:class:`torch_geometric.data.Data` input and :class:`torch.Tensor`/
|
||||
:class:`LabelTensor` target data.
|
||||
"""
|
||||
|
||||
|
||||
class GraphInputGraphTargetCondition(InputTargetCondition):
|
||||
"""
|
||||
InputTargetCondition subclass for :class:`Graph`/:class:`Data` input and
|
||||
target data.
|
||||
InputTargetCondition subclass for :class:`Graph`/
|
||||
:class:`torch_geometric.data.Data` input and target data.
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user