Small fixes in conditions
This commit is contained in:
@@ -11,10 +11,9 @@ from ..graph import Graph
|
||||
|
||||
class DataCondition(ConditionInterface):
|
||||
"""
|
||||
Condition for data. This condition must be used every
|
||||
time a Unsupervised Loss is needed in the Solver. The conditionalvariable
|
||||
can be passed as extra-input when the model learns a conditional
|
||||
distribution
|
||||
This condition must be used every time a Unsupervised Loss is needed in
|
||||
the Solver. The conditionalvariable can be passed as extra-input when
|
||||
the model learns a conditional distribution.
|
||||
"""
|
||||
|
||||
__slots__ = ["input", "conditional_variables"]
|
||||
@@ -27,14 +26,16 @@ class DataCondition(ConditionInterface):
|
||||
types of input data.
|
||||
|
||||
:param input: Input data for the condition.
|
||||
:type input: torch.Tensor | LabelTensor | Graph | Data
|
||||
:type input: torch.Tensor | LabelTensor | Graph | \
|
||||
torch_geometric.data.Data
|
||||
:param conditional_variables: Conditional variables for the condition.
|
||||
:type conditional_variables: torch.Tensor | LabelTensor
|
||||
:return: Subclass of DataCondition.
|
||||
:rtype: TensorDataCondition | GraphDataCondition
|
||||
|
||||
:raises ValueError: If input is not of type :class:`torch.Tensor`,
|
||||
:class:`LabelTensor`, :class:`Graph`, or :class:`Data`.
|
||||
:class:`LabelTensor`, :class:`Graph`, or
|
||||
:class:`torch_geometric.data.Data`.
|
||||
|
||||
|
||||
"""
|
||||
@@ -51,7 +52,7 @@ class DataCondition(ConditionInterface):
|
||||
|
||||
raise ValueError(
|
||||
"Invalid input types. "
|
||||
"Please provide either Data or Graph objects."
|
||||
"Please provide either torch_geometric.data.Data or Graph objects."
|
||||
)
|
||||
|
||||
def __init__(self, input, conditional_variables=None):
|
||||
@@ -60,14 +61,14 @@ class DataCondition(ConditionInterface):
|
||||
variables (if any).
|
||||
|
||||
:param input: Input data for the condition.
|
||||
:type input: torch.Tensor or Graph or Data
|
||||
:type input: torch.Tensor or Graph or torch_geometric.data.Data
|
||||
:param conditional_variables: Conditional variables for the condition.
|
||||
:type conditional_variables: torch.Tensor or LabelTensor
|
||||
|
||||
.. note::
|
||||
If either `input` is composed by a list of :class:`Graph`/
|
||||
:class:`Data` objects, all elements must have the same structure
|
||||
(keys and data types)
|
||||
:class:`torch_geometric.data.Data` objects, all elements must have
|
||||
the same structure (keys and data types)
|
||||
"""
|
||||
super().__init__()
|
||||
self.input = input
|
||||
|
||||
Reference in New Issue
Block a user