From 56e45c6c8775ab45332172722c4806dd962cc992 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Tue, 11 Mar 2025 16:26:28 +0100 Subject: [PATCH] Small fixes in conditions --- pina/condition/condition.py | 39 ++++++++++---------- pina/condition/condition_interface.py | 10 +++--- pina/condition/data_condition.py | 21 +++++------ pina/condition/input_equation_condition.py | 15 ++++---- pina/condition/input_target_condition.py | 42 +++++++++++----------- 5 files changed, 67 insertions(+), 60 deletions(-) diff --git a/pina/condition/condition.py b/pina/condition/condition.py index ccb5d14..7a438b9 100644 --- a/pina/condition/condition.py +++ b/pina/condition/condition.py @@ -56,23 +56,23 @@ class Condition: Example:: - >>> from pina import Condition - >>> condition = Condition( - ... input=input, - ... target=target - ... ) - >>> condition = Condition( - ... domain=location, - ... equation=equation - ... ) - >>> condition = Condition( - ... input=input, - ... equation=equation - ... ) - >>> condition = Condition( - ... input=data, - ... conditional_variables=conditional_variables - ... ) + >>> from pina import Condition + >>> condition = Condition( + ... input=input, + ... target=target + ... ) + >>> condition = Condition( + ... domain=location, + ... equation=equation + ... ) + >>> condition = Condition( + ... input=input, + ... equation=equation + ... ) + >>> condition = Condition( + ... input=data, + ... conditional_variables=conditional_variables + ... ) """ __slots__ = list( @@ -87,6 +87,7 @@ class Condition: def __new__(cls, *args, **kwargs): """ Create a new condition object based on the keyword arguments passed. + - ``input`` and ``target``: :class:`InputTargetCondition` - ``domain`` and ``equation``: :class:`DomainEquationCondition` - ``input`` and ``equation``: :class:`InputEquationCondition` @@ -95,8 +96,8 @@ class Condition: :raises ValueError: No valid condition has been found. :return: A new condition instance belonging to the proper class. - :rtype: ConditionInputTarget | ConditionInputEquation | - ConditionDomainEquation | ConditionData + :rtype: InputTargetCondition | DomainEquationCondition | + InputEquationCondition | DataCondition """ if len(args) != 0: raise ValueError( diff --git a/pina/condition/condition_interface.py b/pina/condition/condition_interface.py index 382eb75..b6f4be7 100644 --- a/pina/condition/condition_interface.py +++ b/pina/condition/condition_interface.py @@ -41,12 +41,14 @@ class ConditionInterface(metaclass=ABCMeta): @staticmethod def _check_graph_list_consistency(data_list): """ - Check if the list of Data/Graph objects is consistent. + Check if the list of :class:`torch_geometric.data.Data`/:class:`Graph` + objects is consistent. - :param data_list: list of Data/Graph objects. - :type data_list: list(Data) | list(Graph) + :param data_list: List of graph type objects. + :type data_list: list(torch_geometric.data.Data) | list(Graph) - :raises ValueError: Input data must be either Data or Graph objects. + :raises ValueError: Input data must be either torch_geometric.data.Data + or Graph objects. :raises ValueError: All elements in the list must have the same keys. :raises ValueError: Type mismatch in data tensors. :raises ValueError: Label mismatch in LabelTensors. diff --git a/pina/condition/data_condition.py b/pina/condition/data_condition.py index b43bfd1..89fdf22 100644 --- a/pina/condition/data_condition.py +++ b/pina/condition/data_condition.py @@ -11,10 +11,9 @@ from ..graph import Graph class DataCondition(ConditionInterface): """ - Condition for data. This condition must be used every - time a Unsupervised Loss is needed in the Solver. The conditionalvariable - can be passed as extra-input when the model learns a conditional - distribution + This condition must be used every time a Unsupervised Loss is needed in + the Solver. The conditionalvariable can be passed as extra-input when + the model learns a conditional distribution. """ __slots__ = ["input", "conditional_variables"] @@ -27,14 +26,16 @@ class DataCondition(ConditionInterface): types of input data. :param input: Input data for the condition. - :type input: torch.Tensor | LabelTensor | Graph | Data + :type input: torch.Tensor | LabelTensor | Graph | \ + torch_geometric.data.Data :param conditional_variables: Conditional variables for the condition. :type conditional_variables: torch.Tensor | LabelTensor :return: Subclass of DataCondition. :rtype: TensorDataCondition | GraphDataCondition :raises ValueError: If input is not of type :class:`torch.Tensor`, - :class:`LabelTensor`, :class:`Graph`, or :class:`Data`. + :class:`LabelTensor`, :class:`Graph`, or + :class:`torch_geometric.data.Data`. """ @@ -51,7 +52,7 @@ class DataCondition(ConditionInterface): raise ValueError( "Invalid input types. " - "Please provide either Data or Graph objects." + "Please provide either torch_geometric.data.Data or Graph objects." ) def __init__(self, input, conditional_variables=None): @@ -60,14 +61,14 @@ class DataCondition(ConditionInterface): variables (if any). :param input: Input data for the condition. - :type input: torch.Tensor or Graph or Data + :type input: torch.Tensor or Graph or torch_geometric.data.Data :param conditional_variables: Conditional variables for the condition. :type conditional_variables: torch.Tensor or LabelTensor .. note:: If either `input` is composed by a list of :class:`Graph`/ - :class:`Data` objects, all elements must have the same structure - (keys and data types) + :class:`torch_geometric.data.Data` objects, all elements must have + the same structure (keys and data types) """ super().__init__() self.input = input diff --git a/pina/condition/input_equation_condition.py b/pina/condition/input_equation_condition.py index 0d0ebc7..81f4271 100644 --- a/pina/condition/input_equation_condition.py +++ b/pina/condition/input_equation_condition.py @@ -33,7 +33,8 @@ class InputEquationCondition(ConditionInterface): :rtype: InputTensorEquationCondition | InputGraphEquationCondition :raises ValueError: If input is not of type :class:`torch.Tensor`, - :class:`LabelTensor`, :class:`Graph`, or :class:`Data`. + :class:`LabelTensor`, :class:`Graph`, or + :class:`torch_geometric.data.Data`. """ # If the class is already a subclass, return the instance @@ -59,16 +60,16 @@ class InputEquationCondition(ConditionInterface): """ Initialize the InputEquationCondition by storing the input and equation. - :param input: torch.Tensor or Graph/Data object containing the input. + :param input: Input data for the condition. :type input: torch.Tensor | Graph :param EquationInterface equation: Equation object containing the equation function. .. note:: - If ``input`` is composed by a list of :class:`Graph`/:class:`Data` - objects, all elements must have the same structure (keys and data - types). Moreover, at least one attribute must be a - :class:`LabelTensor`. + If ``input`` is composed by a list of :class:`Graph`/ + :class:`torch_geometric.data.Data` objects, all elements must have + the same structure (keys and data types). Moreover, at least one + attribute must be a :class:`LabelTensor`. """ super().__init__() @@ -103,7 +104,7 @@ class InputGraphEquationCondition(InputEquationCondition): Check if at least one LabelTensor is present in the Graph object. :param input: Input data. - :type input: torch.Tensor or Graph or Data + :type input: torch.Tensor | Graph | torch_geometric.data.Data :raises ValueError: If the input data object does not contain at least one LabelTensor. diff --git a/pina/condition/input_target_condition.py b/pina/condition/input_target_condition.py index b8d2dbf..24dcd3e 100644 --- a/pina/condition/input_target_condition.py +++ b/pina/condition/input_target_condition.py @@ -25,19 +25,20 @@ class InputTargetCondition(ConditionInterface): the types of input and target data. :param input: Input data for the condition. - :type input: torch.Tensor | Graph | Data | list | tuple + :type input: torch.Tensor | Graph | torch_geometric.data.Data | list | \ + tuple :param target: Target data for the condition. - Graph, Data, or list/tuple. - :type target: torch.Tensor | Graph | Data | list | tuple + :type target: torch.Tensor | Graph | torch_geometric.data.Data | list \ + | tuple :return: Subclass of InputTargetCondition - :rtype: TensorInputTensorTargetCondition | - TensorInputGraphTargetCondition | - GraphInputTensorTargetCondition | + :rtype: TensorInputTensorTargetCondition | \ + TensorInputGraphTargetCondition | \ + GraphInputTensorTargetCondition | \ GraphInputGraphTargetCondition :raises ValueError: If input and or target are not of type :class:`torch.Tensor`, :class:`LabelTensor`, :class:`Graph`, or - :class:`Data`. + :class:`torch_geometric.data.Data`. """ if cls != InputTargetCondition: return super().__new__(cls) @@ -71,23 +72,23 @@ class InputTargetCondition(ConditionInterface): raise ValueError( "Invalid input/target types. " - "Please provide either Data, Graph, LabelTensor or torch.Tensor " - "objects." + "Please provide either torch_geometric.data.Data, Graph, LabelTensor " + "or torch.Tensor objects." ) def __init__(self, input, target): """ Initialize the InputTargetCondition, storing the input and target data. - :param input: torch.Tensor or Graph/Data object containing the input - :type input: torch.Tensor | Graph or Data - :param target: torch.Tensor or Graph/Data object containing the target - :type target: torch.Tensor or Graph or Data + :param input: Input data for the condition. + :type input: torch.Tensor | Graph | torch_geometric.data.Data + :param target: Target data for the condition. + :type target: torch.Tensor | Graph | torch_geometric.data.Data .. note:: If either ``input`` or ``target`` are composed by a list of - :class:`Graph`/:class:`Data` objects, all elements must have the - same structure (keys and data types) + :class:`Graph`/:class:`torch_geometric.data.Data` objects, all + elements must have the same structure (keys and data types) """ super().__init__() @@ -117,19 +118,20 @@ class TensorInputTensorTargetCondition(InputTargetCondition): class TensorInputGraphTargetCondition(InputTargetCondition): """ InputTargetCondition subclass for :class:`torch.Tensor`/:class:`LabelTensor` - input and :class:`Graph`/:class:`Data` target data. + input and :class:`Graph`/:class:`torch_geometric.data.Data` target data. """ class GraphInputTensorTargetCondition(InputTargetCondition): """ - InputTargetCondition subclass for :class:`Graph`/:class:`Data` input and - :class:`torch.Tensor`/:class:`LabelTensor` target data. + InputTargetCondition subclass for :class:`Graph`/ + :class:`torch_geometric.data.Data` input and :class:`torch.Tensor`/ + :class:`LabelTensor` target data. """ class GraphInputGraphTargetCondition(InputTargetCondition): """ - InputTargetCondition subclass for :class:`Graph`/:class:`Data` input and - target data. + InputTargetCondition subclass for :class:`Graph`/ + :class:`torch_geometric.data.Data` input and target data. """