diff --git a/pina/condition/condition.py b/pina/condition/condition.py index d44f124..0e8bd34 100644 --- a/pina/condition/condition.py +++ b/pina/condition/condition.py @@ -98,11 +98,11 @@ class Condition: - `input`: :class:`~pina.condition.data_condition.DataCondition` - `input` and `conditional_variables`: :class:`~pina.condition.data_condition.DataCondition` - - :raises ValueError: No valid condition has been found. :return: A new condition instance belonging to the proper class. :rtype: InputTargetCondition | DomainEquationCondition | InputEquationCondition | DataCondition + + :raises ValueError: No valid condition has been found. """ if len(args) != 0: raise ValueError( diff --git a/pina/condition/data_condition.py b/pina/condition/data_condition.py index 2d776a2..4643b2a 100644 --- a/pina/condition/data_condition.py +++ b/pina/condition/data_condition.py @@ -11,9 +11,8 @@ from ..graph import Graph class DataCondition(ConditionInterface): """ - This condition must be used every time a Unsupervised Loss is needed in - the Solver. The `conditional_variable` can be passed as extra-input when - the model learns a conditional distribution. + Condition defined by input data and conditional variables. It can be used + in unsupervised learning problems. """ __slots__ = ["input", "conditional_variables"] @@ -22,14 +21,14 @@ class DataCondition(ConditionInterface): def __new__(cls, input, conditional_variables=None): """ - Instantiate the appropriate subclass of DataCondition based on the - types of input data. + Instantiate the appropriate subclass of :class:`DataCondition` based on + the type of `input`. :param input: Input data for the condition. :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data] :param conditional_variables: Conditional variables for the condition. - :type conditional_variables: torch.Tensor | LabelTensor + :type conditional_variables: torch.Tensor | LabelTensor, optional :return: Subclass of DataCondition. :rtype: pina.condition.data_condition.TensorDataCondition | pina.condition.data_condition.GraphDataCondition @@ -37,9 +36,8 @@ class DataCondition(ConditionInterface): :raises ValueError: If input is not of type :class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`. - - """ + if cls != DataCondition: return super().__new__(cls) if isinstance(input, (torch.Tensor, LabelTensor)): @@ -69,8 +67,8 @@ class DataCondition(ConditionInterface): .. note:: If either `input` is composed by a list of - :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` - objects, all elements must have the same structure (keys and data + :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`, + all elements must have the same structure (keys and data types) """ diff --git a/pina/condition/domain_equation_condition.py b/pina/condition/domain_equation_condition.py index 2ec9489..92cffff 100644 --- a/pina/condition/domain_equation_condition.py +++ b/pina/condition/domain_equation_condition.py @@ -10,15 +10,16 @@ from ..equation.equation_interface import EquationInterface class DomainEquationCondition(ConditionInterface): """ - Condition for domain/equation data. This condition must be used every - time a Physics Informed Loss is needed in the Solver. + Condition defined by a domain and an equation. It can be used in Physics + Informed problems. Before using this condition, make sure that input data + are correctly sampled from the domain. """ __slots__ = ["domain", "equation"] def __init__(self, domain, equation): """ - Initialize the object by storing the domain and equation. + Initialise the object by storing the domain and equation. :param DomainInterface domain: Domain object containing the domain data. :param EquationInterface equation: Equation object containing the diff --git a/pina/condition/input_equation_condition.py b/pina/condition/input_equation_condition.py index 383eb3e..db78a80 100644 --- a/pina/condition/input_equation_condition.py +++ b/pina/condition/input_equation_condition.py @@ -12,8 +12,8 @@ from ..equation.equation_interface import EquationInterface class InputEquationCondition(ConditionInterface): """ - Condition for input/equation data. This condition must be used every - time a Physics Informed Loss is needed in the Solver. + Condition defined by input data and an equation. This condition can be + used in a Physics Informed problems. """ __slots__ = ["input", "equation"] @@ -22,10 +22,10 @@ class InputEquationCondition(ConditionInterface): def __new__(cls, input, equation): """ - Instantiate the appropriate subclass of InputEquationCondition based on - the type of input data. + Instantiate the appropriate subclass of :class:`InputEquationCondition` + based on the type of `input`. - :param input: Input data. It can be a LabelTensor or a Graph object. + :param input: Input data for the condition. :type input: LabelTensor | Graph | list[Graph] | tuple[Graph] :param EquationInterface equation: Equation object containing the equation function. diff --git a/pina/condition/input_target_condition.py b/pina/condition/input_target_condition.py index 11070fe..6d4c524 100644 --- a/pina/condition/input_target_condition.py +++ b/pina/condition/input_target_condition.py @@ -11,8 +11,8 @@ from .condition_interface import ConditionInterface class InputTargetCondition(ConditionInterface): """ - Condition for domain/equation data. This condition must be used every - time a Physics Informed or a Supervised Loss is needed in the Solver. + Condition defined by input and target data. This condition can be used in + both supervised learning and Physics-informed problems. """ __slots__ = ["input", "target"] @@ -25,15 +25,11 @@ class InputTargetCondition(ConditionInterface): the types of input and target data. :param input: Input data for the condition. - :type input: torch.Tensor | LabelTensor | Graph | - torch_geometric.data.Data | list[Graph] | - list[torch_geometric.data.Data] | tuple[Graph] | - tuple[torch_geometric.data.Data] + :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | + list[Data] | tuple[Graph] | tuple[Data] :param target: Target data for the condition. - :type target: torch.Tensor | LabelTensor | Graph | - torch_geometric.data.Data | list[Graph] | - list[torch_geometric.data.Data] | tuple[Graph] | - tuple[torch_geometric.data.Data] + :type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | + list[Data] | tuple[Graph] | tuple[Data] :return: Subclass of InputTargetCondition :rtype: pina.condition.input_target_condition. TensorInputTensorTargetCondition | @@ -43,7 +39,7 @@ class InputTargetCondition(ConditionInterface): GraphInputTensorTargetCondition | pina.condition.input_target_condition.GraphInputGraphTargetCondition - :raises ValueError: If input and or target are not of type + :raises ValueError: If `input` and/or `target` are not of type :class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`. """ @@ -85,12 +81,11 @@ class InputTargetCondition(ConditionInterface): def __init__(self, input, target): """ - Initialize the object storing the input and target data. + Initialize the object by storing the `input` and `target` data. :param input: Input data for the condition. :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | - list[Data] | tuple[Graph] | - tuple[Data] + list[Data] | tuple[Graph] | tuple[Data] :param target: Target data for the condition. :type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data] @@ -122,15 +117,15 @@ class InputTargetCondition(ConditionInterface): class TensorInputTensorTargetCondition(InputTargetCondition): """ InputTargetCondition subclass for :class:`torch.Tensor` or - :class:`~pina.label_tensor.LabelTensor` input and target data. + :class:`~pina.label_tensor.LabelTensor` `input` and `target` data. """ class TensorInputGraphTargetCondition(InputTargetCondition): """ InputTargetCondition subclass for :class:`torch.Tensor` or - :class:`~pina.label_tensor.LabelTensor` input and - :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` target + :class:`~pina.label_tensor.LabelTensor` `input` and + :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` `target` data. """ @@ -138,13 +133,13 @@ class TensorInputGraphTargetCondition(InputTargetCondition): class GraphInputTensorTargetCondition(InputTargetCondition): """ InputTargetCondition subclass for :class:`~pina.graph.Graph` o - :class:`~torch_geometric.data.Data` input and :class:`torch.Tensor` or - :class:`~pina.label_tensor.LabelTensor` target data. + :class:`~torch_geometric.data.Data` `input` and :class:`torch.Tensor` or + :class:`~pina.label_tensor.LabelTensor` `target` data. """ class GraphInputGraphTargetCondition(InputTargetCondition): """ InputTargetCondition subclass for :class:`~pina.graph.Graph`/ - :class:`~torch_geometric.data.Data` input and target data. + :class:`~torch_geometric.data.Data` `input` and `target` data. """