Small fixes in conditions
This commit is contained in:
committed by
Nicola Demo
parent
63028f2b01
commit
56e45c6c87
@@ -56,23 +56,23 @@ class Condition:
|
|||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
>>> from pina import Condition
|
>>> from pina import Condition
|
||||||
>>> condition = Condition(
|
>>> condition = Condition(
|
||||||
... input=input,
|
... input=input,
|
||||||
... target=target
|
... target=target
|
||||||
... )
|
... )
|
||||||
>>> condition = Condition(
|
>>> condition = Condition(
|
||||||
... domain=location,
|
... domain=location,
|
||||||
... equation=equation
|
... equation=equation
|
||||||
... )
|
... )
|
||||||
>>> condition = Condition(
|
>>> condition = Condition(
|
||||||
... input=input,
|
... input=input,
|
||||||
... equation=equation
|
... equation=equation
|
||||||
... )
|
... )
|
||||||
>>> condition = Condition(
|
>>> condition = Condition(
|
||||||
... input=data,
|
... input=data,
|
||||||
... conditional_variables=conditional_variables
|
... conditional_variables=conditional_variables
|
||||||
... )
|
... )
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__slots__ = list(
|
__slots__ = list(
|
||||||
@@ -87,6 +87,7 @@ class Condition:
|
|||||||
def __new__(cls, *args, **kwargs):
|
def __new__(cls, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Create a new condition object based on the keyword arguments passed.
|
Create a new condition object based on the keyword arguments passed.
|
||||||
|
|
||||||
- ``input`` and ``target``: :class:`InputTargetCondition`
|
- ``input`` and ``target``: :class:`InputTargetCondition`
|
||||||
- ``domain`` and ``equation``: :class:`DomainEquationCondition`
|
- ``domain`` and ``equation``: :class:`DomainEquationCondition`
|
||||||
- ``input`` and ``equation``: :class:`InputEquationCondition`
|
- ``input`` and ``equation``: :class:`InputEquationCondition`
|
||||||
@@ -95,8 +96,8 @@ class Condition:
|
|||||||
|
|
||||||
:raises ValueError: No valid condition has been found.
|
:raises ValueError: No valid condition has been found.
|
||||||
:return: A new condition instance belonging to the proper class.
|
:return: A new condition instance belonging to the proper class.
|
||||||
:rtype: ConditionInputTarget | ConditionInputEquation |
|
:rtype: InputTargetCondition | DomainEquationCondition |
|
||||||
ConditionDomainEquation | ConditionData
|
InputEquationCondition | DataCondition
|
||||||
"""
|
"""
|
||||||
if len(args) != 0:
|
if len(args) != 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -41,12 +41,14 @@ class ConditionInterface(metaclass=ABCMeta):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _check_graph_list_consistency(data_list):
|
def _check_graph_list_consistency(data_list):
|
||||||
"""
|
"""
|
||||||
Check if the list of Data/Graph objects is consistent.
|
Check if the list of :class:`torch_geometric.data.Data`/:class:`Graph`
|
||||||
|
objects is consistent.
|
||||||
|
|
||||||
:param data_list: list of Data/Graph objects.
|
:param data_list: List of graph type objects.
|
||||||
:type data_list: list(Data) | list(Graph)
|
:type data_list: list(torch_geometric.data.Data) | list(Graph)
|
||||||
|
|
||||||
:raises ValueError: Input data must be either Data or Graph objects.
|
:raises ValueError: Input data must be either torch_geometric.data.Data
|
||||||
|
or Graph objects.
|
||||||
:raises ValueError: All elements in the list must have the same keys.
|
:raises ValueError: All elements in the list must have the same keys.
|
||||||
:raises ValueError: Type mismatch in data tensors.
|
:raises ValueError: Type mismatch in data tensors.
|
||||||
:raises ValueError: Label mismatch in LabelTensors.
|
:raises ValueError: Label mismatch in LabelTensors.
|
||||||
|
|||||||
@@ -11,10 +11,9 @@ from ..graph import Graph
|
|||||||
|
|
||||||
class DataCondition(ConditionInterface):
|
class DataCondition(ConditionInterface):
|
||||||
"""
|
"""
|
||||||
Condition for data. This condition must be used every
|
This condition must be used every time a Unsupervised Loss is needed in
|
||||||
time a Unsupervised Loss is needed in the Solver. The conditionalvariable
|
the Solver. The conditionalvariable can be passed as extra-input when
|
||||||
can be passed as extra-input when the model learns a conditional
|
the model learns a conditional distribution.
|
||||||
distribution
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__slots__ = ["input", "conditional_variables"]
|
__slots__ = ["input", "conditional_variables"]
|
||||||
@@ -27,14 +26,16 @@ class DataCondition(ConditionInterface):
|
|||||||
types of input data.
|
types of input data.
|
||||||
|
|
||||||
:param input: Input data for the condition.
|
:param input: Input data for the condition.
|
||||||
:type input: torch.Tensor | LabelTensor | Graph | Data
|
:type input: torch.Tensor | LabelTensor | Graph | \
|
||||||
|
torch_geometric.data.Data
|
||||||
:param conditional_variables: Conditional variables for the condition.
|
:param conditional_variables: Conditional variables for the condition.
|
||||||
:type conditional_variables: torch.Tensor | LabelTensor
|
:type conditional_variables: torch.Tensor | LabelTensor
|
||||||
:return: Subclass of DataCondition.
|
:return: Subclass of DataCondition.
|
||||||
:rtype: TensorDataCondition | GraphDataCondition
|
:rtype: TensorDataCondition | GraphDataCondition
|
||||||
|
|
||||||
:raises ValueError: If input is not of type :class:`torch.Tensor`,
|
:raises ValueError: If input is not of type :class:`torch.Tensor`,
|
||||||
:class:`LabelTensor`, :class:`Graph`, or :class:`Data`.
|
:class:`LabelTensor`, :class:`Graph`, or
|
||||||
|
:class:`torch_geometric.data.Data`.
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -51,7 +52,7 @@ class DataCondition(ConditionInterface):
|
|||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Invalid input types. "
|
"Invalid input types. "
|
||||||
"Please provide either Data or Graph objects."
|
"Please provide either torch_geometric.data.Data or Graph objects."
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, input, conditional_variables=None):
|
def __init__(self, input, conditional_variables=None):
|
||||||
@@ -60,14 +61,14 @@ class DataCondition(ConditionInterface):
|
|||||||
variables (if any).
|
variables (if any).
|
||||||
|
|
||||||
:param input: Input data for the condition.
|
:param input: Input data for the condition.
|
||||||
:type input: torch.Tensor or Graph or Data
|
:type input: torch.Tensor or Graph or torch_geometric.data.Data
|
||||||
:param conditional_variables: Conditional variables for the condition.
|
:param conditional_variables: Conditional variables for the condition.
|
||||||
:type conditional_variables: torch.Tensor or LabelTensor
|
:type conditional_variables: torch.Tensor or LabelTensor
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
If either `input` is composed by a list of :class:`Graph`/
|
If either `input` is composed by a list of :class:`Graph`/
|
||||||
:class:`Data` objects, all elements must have the same structure
|
:class:`torch_geometric.data.Data` objects, all elements must have
|
||||||
(keys and data types)
|
the same structure (keys and data types)
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.input = input
|
self.input = input
|
||||||
|
|||||||
@@ -33,7 +33,8 @@ class InputEquationCondition(ConditionInterface):
|
|||||||
:rtype: InputTensorEquationCondition | InputGraphEquationCondition
|
:rtype: InputTensorEquationCondition | InputGraphEquationCondition
|
||||||
|
|
||||||
:raises ValueError: If input is not of type :class:`torch.Tensor`,
|
:raises ValueError: If input is not of type :class:`torch.Tensor`,
|
||||||
:class:`LabelTensor`, :class:`Graph`, or :class:`Data`.
|
:class:`LabelTensor`, :class:`Graph`, or
|
||||||
|
:class:`torch_geometric.data.Data`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# If the class is already a subclass, return the instance
|
# If the class is already a subclass, return the instance
|
||||||
@@ -59,16 +60,16 @@ class InputEquationCondition(ConditionInterface):
|
|||||||
"""
|
"""
|
||||||
Initialize the InputEquationCondition by storing the input and equation.
|
Initialize the InputEquationCondition by storing the input and equation.
|
||||||
|
|
||||||
:param input: torch.Tensor or Graph/Data object containing the input.
|
:param input: Input data for the condition.
|
||||||
:type input: torch.Tensor | Graph
|
:type input: torch.Tensor | Graph
|
||||||
:param EquationInterface equation: Equation object containing the
|
:param EquationInterface equation: Equation object containing the
|
||||||
equation function.
|
equation function.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
If ``input`` is composed by a list of :class:`Graph`/:class:`Data`
|
If ``input`` is composed by a list of :class:`Graph`/
|
||||||
objects, all elements must have the same structure (keys and data
|
:class:`torch_geometric.data.Data` objects, all elements must have
|
||||||
types). Moreover, at least one attribute must be a
|
the same structure (keys and data types). Moreover, at least one
|
||||||
:class:`LabelTensor`.
|
attribute must be a :class:`LabelTensor`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -103,7 +104,7 @@ class InputGraphEquationCondition(InputEquationCondition):
|
|||||||
Check if at least one LabelTensor is present in the Graph object.
|
Check if at least one LabelTensor is present in the Graph object.
|
||||||
|
|
||||||
:param input: Input data.
|
:param input: Input data.
|
||||||
:type input: torch.Tensor or Graph or Data
|
:type input: torch.Tensor | Graph | torch_geometric.data.Data
|
||||||
|
|
||||||
:raises ValueError: If the input data object does not contain at least
|
:raises ValueError: If the input data object does not contain at least
|
||||||
one LabelTensor.
|
one LabelTensor.
|
||||||
|
|||||||
@@ -25,19 +25,20 @@ class InputTargetCondition(ConditionInterface):
|
|||||||
the types of input and target data.
|
the types of input and target data.
|
||||||
|
|
||||||
:param input: Input data for the condition.
|
:param input: Input data for the condition.
|
||||||
:type input: torch.Tensor | Graph | Data | list | tuple
|
:type input: torch.Tensor | Graph | torch_geometric.data.Data | list | \
|
||||||
|
tuple
|
||||||
:param target: Target data for the condition.
|
:param target: Target data for the condition.
|
||||||
Graph, Data, or list/tuple.
|
:type target: torch.Tensor | Graph | torch_geometric.data.Data | list \
|
||||||
:type target: torch.Tensor | Graph | Data | list | tuple
|
| tuple
|
||||||
:return: Subclass of InputTargetCondition
|
:return: Subclass of InputTargetCondition
|
||||||
:rtype: TensorInputTensorTargetCondition |
|
:rtype: TensorInputTensorTargetCondition | \
|
||||||
TensorInputGraphTargetCondition |
|
TensorInputGraphTargetCondition | \
|
||||||
GraphInputTensorTargetCondition |
|
GraphInputTensorTargetCondition | \
|
||||||
GraphInputGraphTargetCondition
|
GraphInputGraphTargetCondition
|
||||||
|
|
||||||
:raises ValueError: If input and or target are not of type
|
:raises ValueError: If input and or target are not of type
|
||||||
:class:`torch.Tensor`, :class:`LabelTensor`, :class:`Graph`, or
|
:class:`torch.Tensor`, :class:`LabelTensor`, :class:`Graph`, or
|
||||||
:class:`Data`.
|
:class:`torch_geometric.data.Data`.
|
||||||
"""
|
"""
|
||||||
if cls != InputTargetCondition:
|
if cls != InputTargetCondition:
|
||||||
return super().__new__(cls)
|
return super().__new__(cls)
|
||||||
@@ -71,23 +72,23 @@ class InputTargetCondition(ConditionInterface):
|
|||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Invalid input/target types. "
|
"Invalid input/target types. "
|
||||||
"Please provide either Data, Graph, LabelTensor or torch.Tensor "
|
"Please provide either torch_geometric.data.Data, Graph, LabelTensor "
|
||||||
"objects."
|
"or torch.Tensor objects."
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, input, target):
|
def __init__(self, input, target):
|
||||||
"""
|
"""
|
||||||
Initialize the InputTargetCondition, storing the input and target data.
|
Initialize the InputTargetCondition, storing the input and target data.
|
||||||
|
|
||||||
:param input: torch.Tensor or Graph/Data object containing the input
|
:param input: Input data for the condition.
|
||||||
:type input: torch.Tensor | Graph or Data
|
:type input: torch.Tensor | Graph | torch_geometric.data.Data
|
||||||
:param target: torch.Tensor or Graph/Data object containing the target
|
:param target: Target data for the condition.
|
||||||
:type target: torch.Tensor or Graph or Data
|
:type target: torch.Tensor | Graph | torch_geometric.data.Data
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
If either ``input`` or ``target`` are composed by a list of
|
If either ``input`` or ``target`` are composed by a list of
|
||||||
:class:`Graph`/:class:`Data` objects, all elements must have the
|
:class:`Graph`/:class:`torch_geometric.data.Data` objects, all
|
||||||
same structure (keys and data types)
|
elements must have the same structure (keys and data types)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -117,19 +118,20 @@ class TensorInputTensorTargetCondition(InputTargetCondition):
|
|||||||
class TensorInputGraphTargetCondition(InputTargetCondition):
|
class TensorInputGraphTargetCondition(InputTargetCondition):
|
||||||
"""
|
"""
|
||||||
InputTargetCondition subclass for :class:`torch.Tensor`/:class:`LabelTensor`
|
InputTargetCondition subclass for :class:`torch.Tensor`/:class:`LabelTensor`
|
||||||
input and :class:`Graph`/:class:`Data` target data.
|
input and :class:`Graph`/:class:`torch_geometric.data.Data` target data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class GraphInputTensorTargetCondition(InputTargetCondition):
|
class GraphInputTensorTargetCondition(InputTargetCondition):
|
||||||
"""
|
"""
|
||||||
InputTargetCondition subclass for :class:`Graph`/:class:`Data` input and
|
InputTargetCondition subclass for :class:`Graph`/
|
||||||
:class:`torch.Tensor`/:class:`LabelTensor` target data.
|
:class:`torch_geometric.data.Data` input and :class:`torch.Tensor`/
|
||||||
|
:class:`LabelTensor` target data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class GraphInputGraphTargetCondition(InputTargetCondition):
|
class GraphInputGraphTargetCondition(InputTargetCondition):
|
||||||
"""
|
"""
|
||||||
InputTargetCondition subclass for :class:`Graph`/:class:`Data` input and
|
InputTargetCondition subclass for :class:`Graph`/
|
||||||
target data.
|
:class:`torch_geometric.data.Data` input and target data.
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user