Small fixes in conditions

This commit is contained in:
FilippoOlivo
2025-03-11 16:26:28 +01:00
committed by Nicola Demo
parent 63028f2b01
commit 56e45c6c87
5 changed files with 67 additions and 60 deletions

View File

@@ -56,23 +56,23 @@ class Condition:
Example:: Example::
>>> from pina import Condition >>> from pina import Condition
>>> condition = Condition( >>> condition = Condition(
... input=input, ... input=input,
... target=target ... target=target
... ) ... )
>>> condition = Condition( >>> condition = Condition(
... domain=location, ... domain=location,
... equation=equation ... equation=equation
... ) ... )
>>> condition = Condition( >>> condition = Condition(
... input=input, ... input=input,
... equation=equation ... equation=equation
... ) ... )
>>> condition = Condition( >>> condition = Condition(
... input=data, ... input=data,
... conditional_variables=conditional_variables ... conditional_variables=conditional_variables
... ) ... )
""" """
__slots__ = list( __slots__ = list(
@@ -87,6 +87,7 @@ class Condition:
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
""" """
Create a new condition object based on the keyword arguments passed. Create a new condition object based on the keyword arguments passed.
- ``input`` and ``target``: :class:`InputTargetCondition` - ``input`` and ``target``: :class:`InputTargetCondition`
- ``domain`` and ``equation``: :class:`DomainEquationCondition` - ``domain`` and ``equation``: :class:`DomainEquationCondition`
- ``input`` and ``equation``: :class:`InputEquationCondition` - ``input`` and ``equation``: :class:`InputEquationCondition`
@@ -95,8 +96,8 @@ class Condition:
:raises ValueError: No valid condition has been found. :raises ValueError: No valid condition has been found.
:return: A new condition instance belonging to the proper class. :return: A new condition instance belonging to the proper class.
:rtype: ConditionInputTarget | ConditionInputEquation | :rtype: InputTargetCondition | DomainEquationCondition |
ConditionDomainEquation | ConditionData InputEquationCondition | DataCondition
""" """
if len(args) != 0: if len(args) != 0:
raise ValueError( raise ValueError(

View File

@@ -41,12 +41,14 @@ class ConditionInterface(metaclass=ABCMeta):
@staticmethod @staticmethod
def _check_graph_list_consistency(data_list): def _check_graph_list_consistency(data_list):
""" """
Check if the list of Data/Graph objects is consistent. Check if the list of :class:`torch_geometric.data.Data`/:class:`Graph`
objects is consistent.
:param data_list: list of Data/Graph objects. :param data_list: List of graph type objects.
:type data_list: list(Data) | list(Graph) :type data_list: list(torch_geometric.data.Data) | list(Graph)
:raises ValueError: Input data must be either Data or Graph objects. :raises ValueError: Input data must be either torch_geometric.data.Data
or Graph objects.
:raises ValueError: All elements in the list must have the same keys. :raises ValueError: All elements in the list must have the same keys.
:raises ValueError: Type mismatch in data tensors. :raises ValueError: Type mismatch in data tensors.
:raises ValueError: Label mismatch in LabelTensors. :raises ValueError: Label mismatch in LabelTensors.

View File

@@ -11,10 +11,9 @@ from ..graph import Graph
class DataCondition(ConditionInterface): class DataCondition(ConditionInterface):
""" """
Condition for data. This condition must be used every This condition must be used every time a Unsupervised Loss is needed in
time a Unsupervised Loss is needed in the Solver. The conditionalvariable the Solver. The conditionalvariable can be passed as extra-input when
can be passed as extra-input when the model learns a conditional the model learns a conditional distribution.
distribution
""" """
__slots__ = ["input", "conditional_variables"] __slots__ = ["input", "conditional_variables"]
@@ -27,14 +26,16 @@ class DataCondition(ConditionInterface):
types of input data. types of input data.
:param input: Input data for the condition. :param input: Input data for the condition.
:type input: torch.Tensor | LabelTensor | Graph | Data :type input: torch.Tensor | LabelTensor | Graph | \
torch_geometric.data.Data
:param conditional_variables: Conditional variables for the condition. :param conditional_variables: Conditional variables for the condition.
:type conditional_variables: torch.Tensor | LabelTensor :type conditional_variables: torch.Tensor | LabelTensor
:return: Subclass of DataCondition. :return: Subclass of DataCondition.
:rtype: TensorDataCondition | GraphDataCondition :rtype: TensorDataCondition | GraphDataCondition
:raises ValueError: If input is not of type :class:`torch.Tensor`, :raises ValueError: If input is not of type :class:`torch.Tensor`,
:class:`LabelTensor`, :class:`Graph`, or :class:`Data`. :class:`LabelTensor`, :class:`Graph`, or
:class:`torch_geometric.data.Data`.
""" """
@@ -51,7 +52,7 @@ class DataCondition(ConditionInterface):
raise ValueError( raise ValueError(
"Invalid input types. " "Invalid input types. "
"Please provide either Data or Graph objects." "Please provide either torch_geometric.data.Data or Graph objects."
) )
def __init__(self, input, conditional_variables=None): def __init__(self, input, conditional_variables=None):
@@ -60,14 +61,14 @@ class DataCondition(ConditionInterface):
variables (if any). variables (if any).
:param input: Input data for the condition. :param input: Input data for the condition.
:type input: torch.Tensor or Graph or Data :type input: torch.Tensor or Graph or torch_geometric.data.Data
:param conditional_variables: Conditional variables for the condition. :param conditional_variables: Conditional variables for the condition.
:type conditional_variables: torch.Tensor or LabelTensor :type conditional_variables: torch.Tensor or LabelTensor
.. note:: .. note::
If either `input` is composed by a list of :class:`Graph`/ If either `input` is composed by a list of :class:`Graph`/
:class:`Data` objects, all elements must have the same structure :class:`torch_geometric.data.Data` objects, all elements must have
(keys and data types) the same structure (keys and data types)
""" """
super().__init__() super().__init__()
self.input = input self.input = input

View File

@@ -33,7 +33,8 @@ class InputEquationCondition(ConditionInterface):
:rtype: InputTensorEquationCondition | InputGraphEquationCondition :rtype: InputTensorEquationCondition | InputGraphEquationCondition
:raises ValueError: If input is not of type :class:`torch.Tensor`, :raises ValueError: If input is not of type :class:`torch.Tensor`,
:class:`LabelTensor`, :class:`Graph`, or :class:`Data`. :class:`LabelTensor`, :class:`Graph`, or
:class:`torch_geometric.data.Data`.
""" """
# If the class is already a subclass, return the instance # If the class is already a subclass, return the instance
@@ -59,16 +60,16 @@ class InputEquationCondition(ConditionInterface):
""" """
Initialize the InputEquationCondition by storing the input and equation. Initialize the InputEquationCondition by storing the input and equation.
:param input: torch.Tensor or Graph/Data object containing the input. :param input: Input data for the condition.
:type input: torch.Tensor | Graph :type input: torch.Tensor | Graph
:param EquationInterface equation: Equation object containing the :param EquationInterface equation: Equation object containing the
equation function. equation function.
.. note:: .. note::
If ``input`` is composed by a list of :class:`Graph`/:class:`Data` If ``input`` is composed by a list of :class:`Graph`/
objects, all elements must have the same structure (keys and data :class:`torch_geometric.data.Data` objects, all elements must have
types). Moreover, at least one attribute must be a the same structure (keys and data types). Moreover, at least one
:class:`LabelTensor`. attribute must be a :class:`LabelTensor`.
""" """
super().__init__() super().__init__()
@@ -103,7 +104,7 @@ class InputGraphEquationCondition(InputEquationCondition):
Check if at least one LabelTensor is present in the Graph object. Check if at least one LabelTensor is present in the Graph object.
:param input: Input data. :param input: Input data.
:type input: torch.Tensor or Graph or Data :type input: torch.Tensor | Graph | torch_geometric.data.Data
:raises ValueError: If the input data object does not contain at least :raises ValueError: If the input data object does not contain at least
one LabelTensor. one LabelTensor.

View File

@@ -25,19 +25,20 @@ class InputTargetCondition(ConditionInterface):
the types of input and target data. the types of input and target data.
:param input: Input data for the condition. :param input: Input data for the condition.
:type input: torch.Tensor | Graph | Data | list | tuple :type input: torch.Tensor | Graph | torch_geometric.data.Data | list | \
tuple
:param target: Target data for the condition. :param target: Target data for the condition.
Graph, Data, or list/tuple. :type target: torch.Tensor | Graph | torch_geometric.data.Data | list \
:type target: torch.Tensor | Graph | Data | list | tuple | tuple
:return: Subclass of InputTargetCondition :return: Subclass of InputTargetCondition
:rtype: TensorInputTensorTargetCondition | :rtype: TensorInputTensorTargetCondition | \
TensorInputGraphTargetCondition | TensorInputGraphTargetCondition | \
GraphInputTensorTargetCondition | GraphInputTensorTargetCondition | \
GraphInputGraphTargetCondition GraphInputGraphTargetCondition
:raises ValueError: If input and or target are not of type :raises ValueError: If input and or target are not of type
:class:`torch.Tensor`, :class:`LabelTensor`, :class:`Graph`, or :class:`torch.Tensor`, :class:`LabelTensor`, :class:`Graph`, or
:class:`Data`. :class:`torch_geometric.data.Data`.
""" """
if cls != InputTargetCondition: if cls != InputTargetCondition:
return super().__new__(cls) return super().__new__(cls)
@@ -71,23 +72,23 @@ class InputTargetCondition(ConditionInterface):
raise ValueError( raise ValueError(
"Invalid input/target types. " "Invalid input/target types. "
"Please provide either Data, Graph, LabelTensor or torch.Tensor " "Please provide either torch_geometric.data.Data, Graph, LabelTensor "
"objects." "or torch.Tensor objects."
) )
def __init__(self, input, target): def __init__(self, input, target):
""" """
Initialize the InputTargetCondition, storing the input and target data. Initialize the InputTargetCondition, storing the input and target data.
:param input: torch.Tensor or Graph/Data object containing the input :param input: Input data for the condition.
:type input: torch.Tensor | Graph or Data :type input: torch.Tensor | Graph | torch_geometric.data.Data
:param target: torch.Tensor or Graph/Data object containing the target :param target: Target data for the condition.
:type target: torch.Tensor or Graph or Data :type target: torch.Tensor | Graph | torch_geometric.data.Data
.. note:: .. note::
If either ``input`` or ``target`` are composed by a list of If either ``input`` or ``target`` are composed by a list of
:class:`Graph`/:class:`Data` objects, all elements must have the :class:`Graph`/:class:`torch_geometric.data.Data` objects, all
same structure (keys and data types) elements must have the same structure (keys and data types)
""" """
super().__init__() super().__init__()
@@ -117,19 +118,20 @@ class TensorInputTensorTargetCondition(InputTargetCondition):
class TensorInputGraphTargetCondition(InputTargetCondition): class TensorInputGraphTargetCondition(InputTargetCondition):
""" """
InputTargetCondition subclass for :class:`torch.Tensor`/:class:`LabelTensor` InputTargetCondition subclass for :class:`torch.Tensor`/:class:`LabelTensor`
input and :class:`Graph`/:class:`Data` target data. input and :class:`Graph`/:class:`torch_geometric.data.Data` target data.
""" """
class GraphInputTensorTargetCondition(InputTargetCondition): class GraphInputTensorTargetCondition(InputTargetCondition):
""" """
InputTargetCondition subclass for :class:`Graph`/:class:`Data` input and InputTargetCondition subclass for :class:`Graph`/
:class:`torch.Tensor`/:class:`LabelTensor` target data. :class:`torch_geometric.data.Data` input and :class:`torch.Tensor`/
:class:`LabelTensor` target data.
""" """
class GraphInputGraphTargetCondition(InputTargetCondition): class GraphInputGraphTargetCondition(InputTargetCondition):
""" """
InputTargetCondition subclass for :class:`Graph`/:class:`Data` input and InputTargetCondition subclass for :class:`Graph`/
target data. :class:`torch_geometric.data.Data` input and target data.
""" """