diff --git a/pina/condition/__init__.py b/pina/condition/__init__.py index 36f4011..ce339be 100644 --- a/pina/condition/__init__.py +++ b/pina/condition/__init__.py @@ -1,5 +1,5 @@ """ -Module for conditions. +Module for importing Conditions objects. """ __all__ = [ diff --git a/pina/condition/condition.py b/pina/condition/condition.py index 53744b4..ccb5d14 100644 --- a/pina/condition/condition.py +++ b/pina/condition/condition.py @@ -1,4 +1,6 @@ -"""Condition module.""" +""" +Condition module. +""" import warnings from .data_condition import DataCondition @@ -13,12 +15,11 @@ warnings.filterwarnings("always", category=DeprecationWarning) def warning_function(new, old): - """Handle the deprecation warning. + """ + Handle the deprecation warning. - :param new: Object to use instead of the old one. - :type new: str - :param old: Object to deprecate. - :type old: str + :param str new: Object to use instead of the old one. + :param str old: Object to deprecate. """ warnings.warn( f"'{old}' is deprecated and will be removed " @@ -72,7 +73,6 @@ class Condition: ... input=data, ... conditional_variables=conditional_variables ... ) - """ __slots__ = list( @@ -85,7 +85,19 @@ class Condition: ) def __new__(cls, *args, **kwargs): + """ + Create a new condition object based on the keyword arguments passed. + - ``input`` and ``target``: :class:`InputTargetCondition` + - ``domain`` and ``equation``: :class:`DomainEquationCondition` + - ``input`` and ``equation``: :class:`InputEquationCondition` + - ``input``: :class:`DataCondition` + - ``input`` and ``conditional_variables``: :class:`DataCondition` + :raises ValueError: No valid condition has been found. + :return: A new condition instance belonging to the proper class. + :rtype: ConditionInputTarget | ConditionInputEquation | + ConditionDomainEquation | ConditionData + """ if len(args) != 0: raise ValueError( "Condition takes only the following keyword " diff --git a/pina/condition/condition_interface.py b/pina/condition/condition_interface.py index 4d748c3..382eb75 100644 --- a/pina/condition/condition_interface.py +++ b/pina/condition/condition_interface.py @@ -21,17 +21,36 @@ class ConditionInterface(metaclass=ABCMeta): """ Return the problem to which the condition is associated. - :return: Problem to which the condition is associated + :return: Problem to which the condition is associated. :rtype: pina.problem.AbstractProblem """ + return self._problem @problem.setter def problem(self, value): + """ + Set the problem to which the condition is associated. + + :param value: Problem to which the condition is associated. + :type value: pina.problem.AbstractProblem + """ + self._problem = value @staticmethod def _check_graph_list_consistency(data_list): + """ + Check if the list of Data/Graph objects is consistent. + + :param data_list: list of Data/Graph objects. + :type data_list: list(Data) | list(Graph) + + :raises ValueError: Input data must be either Data or Graph objects. + :raises ValueError: All elements in the list must have the same keys. + :raises ValueError: Type mismatch in data tensors. + :raises ValueError: Label mismatch in LabelTensors. + """ # If the data is a Graph or Data object, return (do not need to check # anything) diff --git a/pina/condition/data_condition.py b/pina/condition/data_condition.py index 1560157..b43bfd1 100644 --- a/pina/condition/data_condition.py +++ b/pina/condition/data_condition.py @@ -23,17 +23,20 @@ class DataCondition(ConditionInterface): def __new__(cls, input, conditional_variables=None): """ - Instanciate the correct subclass of DataCondition by checking the type - of the input data (input and conditional_variables). + Instantiate the appropriate subclass of DataCondition based on the + types of input data. + + :param input: Input data for the condition. + :type input: torch.Tensor | LabelTensor | Graph | Data + :param conditional_variables: Conditional variables for the condition. + :type conditional_variables: torch.Tensor | LabelTensor + :return: Subclass of DataCondition. + :rtype: TensorDataCondition | GraphDataCondition + + :raises ValueError: If input is not of type :class:`torch.Tensor`, + :class:`LabelTensor`, :class:`Graph`, or :class:`Data`. + - :param input: torch.Tensor or Graph/Data object containing the input - data - :type input: torch.Tensor or Graph or Data - :param conditional_variables: torch.Tensor or LabelTensor containing - the conditional variables - :type conditional_variables: torch.Tensor or LabelTensor - :return: DataCondition subclass - :rtype: TensorDataCondition or GraphDataCondition """ if cls != DataCondition: return super().__new__(cls) @@ -56,12 +59,15 @@ class DataCondition(ConditionInterface): Initialize the DataCondition, storing the input and conditional variables (if any). - :param input: torch.Tensor or Graph/Data object containing the input - data + :param input: Input data for the condition. :type input: torch.Tensor or Graph or Data - :param conditional_variables: torch.Tensor or LabelTensor containing - the conditional variables + :param conditional_variables: Conditional variables for the condition. :type conditional_variables: torch.Tensor or LabelTensor + + .. note:: + If either `input` is composed by a list of :class:`Graph`/ + :class:`Data` objects, all elements must have the same structure + (keys and data types) """ super().__init__() self.input = input diff --git a/pina/condition/domain_equation_condition.py b/pina/condition/domain_equation_condition.py index aad9d9f..58f9a66 100644 --- a/pina/condition/domain_equation_condition.py +++ b/pina/condition/domain_equation_condition.py @@ -20,9 +20,9 @@ class DomainEquationCondition(ConditionInterface): """ Initialize the DomainEquationCondition, storing the domain and equation. - :param DomainInterface domain: Domain object containing the domain data + :param DomainInterface domain: Domain object containing the domain data. :param EquationInterface equation: Equation object containing the - equation data + equation data. """ super().__init__() self.domain = domain diff --git a/pina/condition/input_equation_condition.py b/pina/condition/input_equation_condition.py index 9a267a3..0d0ebc7 100644 --- a/pina/condition/input_equation_condition.py +++ b/pina/condition/input_equation_condition.py @@ -22,15 +22,18 @@ class InputEquationCondition(ConditionInterface): def __new__(cls, input, equation): """ - Instanciate the correct subclass of InputEquationCondition by checking - the type of the input data (only `input`). + Instantiate the appropriate subclass of InputEquationCondition based on + the type of input data. - :param input: torch.Tensor or Graph/Data object containing the input - :type input: torch.Tensor or Graph or Data + :param input: Input data. It can be a LabelTensor or a Graph object. + :type input: LabelTensor | Graph :param EquationInterface equation: Equation object containing the - equation function - :return: InputEquationCondition subclass - :rtype: InputTensorEquationCondition or InputGraphEquationCondition + equation function. + :return: Subclass of InputEquationCondition, based on the input type. + :rtype: InputTensorEquationCondition | InputGraphEquationCondition + + :raises ValueError: If input is not of type :class:`torch.Tensor`, + :class:`LabelTensor`, :class:`Graph`, or :class:`Data`. """ # If the class is already a subclass, return the instance @@ -56,11 +59,18 @@ class InputEquationCondition(ConditionInterface): """ Initialize the InputEquationCondition by storing the input and equation. - :param input: torch.Tensor or Graph/Data object containing the input - :type input: torch.Tensor or Graph or Data + :param input: torch.Tensor or Graph/Data object containing the input. + :type input: torch.Tensor | Graph :param EquationInterface equation: Equation object containing the - equation function + equation function. + + .. note:: + If ``input`` is composed by a list of :class:`Graph`/:class:`Data` + objects, all elements must have the same structure (keys and data + types). Moreover, at least one attribute must be a + :class:`LabelTensor`. """ + super().__init__() self.input = input self.equation = equation @@ -90,11 +100,15 @@ class InputGraphEquationCondition(InputEquationCondition): @staticmethod def _check_label_tensor(input): """ - Check if the input is a LabelTensor. + Check if at least one LabelTensor is present in the Graph object. - :param input: input data + :param input: Input data. :type input: torch.Tensor or Graph or Data + + :raises ValueError: If the input data object does not contain at least + one LabelTensor. """ + # Store the fist element of the list/tuple if input is a list/tuple # it is anougth to check the first element because all elements must # have the same type and structure (already checked) diff --git a/pina/condition/input_target_condition.py b/pina/condition/input_target_condition.py index 70e09bc..b8d2dbf 100644 --- a/pina/condition/input_target_condition.py +++ b/pina/condition/input_target_condition.py @@ -21,17 +21,23 @@ class InputTargetCondition(ConditionInterface): def __new__(cls, input, target): """ - Instanciate the correct subclass of InputTargetCondition by checking the - type of the input and target data. + Instantiate the appropriate subclass of InputTargetCondition based on + the types of input and target data. - :param input: torch.Tensor or Graph/Data object containing the input - :type input: torch.Tensor or Graph or Data - :param target: torch.Tensor or Graph/Data object containing the target - :type target: torch.Tensor or Graph or Data - :return: InputTargetCondition subclass - :rtype: TensorInputTensorTargetCondition or - TensorInputGraphTargetCondition or GraphInputTensorTargetCondition - or GraphInputGraphTargetCondition + :param input: Input data for the condition. + :type input: torch.Tensor | Graph | Data | list | tuple + :param target: Target data for the condition. + Graph, Data, or list/tuple. + :type target: torch.Tensor | Graph | Data | list | tuple + :return: Subclass of InputTargetCondition + :rtype: TensorInputTensorTargetCondition | + TensorInputGraphTargetCondition | + GraphInputTensorTargetCondition | + GraphInputGraphTargetCondition + + :raises ValueError: If input and or target are not of type + :class:`torch.Tensor`, :class:`LabelTensor`, :class:`Graph`, or + :class:`Data`. """ if cls != InputTargetCondition: return super().__new__(cls) @@ -74,10 +80,16 @@ class InputTargetCondition(ConditionInterface): Initialize the InputTargetCondition, storing the input and target data. :param input: torch.Tensor or Graph/Data object containing the input - :type input: torch.Tensor or Graph or Data + :type input: torch.Tensor | Graph or Data :param target: torch.Tensor or Graph/Data object containing the target :type target: torch.Tensor or Graph or Data + + .. note:: + If either ``input`` or ``target`` are composed by a list of + :class:`Graph`/:class:`Data` objects, all elements must have the + same structure (keys and data types) """ + super().__init__() self._check_input_target_len(input, target) self.input = input @@ -97,25 +109,27 @@ class InputTargetCondition(ConditionInterface): class TensorInputTensorTargetCondition(InputTargetCondition): """ - InputTargetCondition subclass for torch.Tensor input and target data. + InputTargetCondition subclass for :class:`torch.Tensor`/:class:`LabelTensor` + input and target data. """ class TensorInputGraphTargetCondition(InputTargetCondition): """ - InputTargetCondition subclass for torch.Tensor input and Graph/Data target - data. + InputTargetCondition subclass for :class:`torch.Tensor`/:class:`LabelTensor` + input and :class:`Graph`/:class:`Data` target data. """ class GraphInputTensorTargetCondition(InputTargetCondition): """ - InputTargetCondition subclass for Graph/Data input and torch.Tensor target - data. + InputTargetCondition subclass for :class:`Graph`/:class:`Data` input and + :class:`torch.Tensor`/:class:`LabelTensor` target data. """ class GraphInputGraphTargetCondition(InputTargetCondition): """ - InputTargetCondition subclass for Graph/Data input and target data. + InputTargetCondition subclass for :class:`Graph`/:class:`Data` input and + target data. """