Doc conditions

This commit is contained in:
FilippoOlivo
2025-03-11 15:22:35 +01:00
parent 92bb04fafe
commit be0e39a050
7 changed files with 119 additions and 54 deletions

View File

@@ -1,5 +1,5 @@
""" """
Module for conditions. Module for importing Conditions objects.
""" """
__all__ = [ __all__ = [

View File

@@ -1,4 +1,6 @@
"""Condition module.""" """
Condition module.
"""
import warnings import warnings
from .data_condition import DataCondition from .data_condition import DataCondition
@@ -13,12 +15,11 @@ warnings.filterwarnings("always", category=DeprecationWarning)
def warning_function(new, old): def warning_function(new, old):
"""Handle the deprecation warning. """
Handle the deprecation warning.
:param new: Object to use instead of the old one. :param str new: Object to use instead of the old one.
:type new: str :param str old: Object to deprecate.
:param old: Object to deprecate.
:type old: str
""" """
warnings.warn( warnings.warn(
f"'{old}' is deprecated and will be removed " f"'{old}' is deprecated and will be removed "
@@ -72,7 +73,6 @@ class Condition:
... input=data, ... input=data,
... conditional_variables=conditional_variables ... conditional_variables=conditional_variables
... ) ... )
""" """
__slots__ = list( __slots__ = list(
@@ -85,7 +85,19 @@ class Condition:
) )
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
"""
Create a new condition object based on the keyword arguments passed.
- ``input`` and ``target``: :class:`InputTargetCondition`
- ``domain`` and ``equation``: :class:`DomainEquationCondition`
- ``input`` and ``equation``: :class:`InputEquationCondition`
- ``input``: :class:`DataCondition`
- ``input`` and ``conditional_variables``: :class:`DataCondition`
:raises ValueError: No valid condition has been found.
:return: A new condition instance belonging to the proper class.
:rtype: ConditionInputTarget | ConditionInputEquation |
ConditionDomainEquation | ConditionData
"""
if len(args) != 0: if len(args) != 0:
raise ValueError( raise ValueError(
"Condition takes only the following keyword " "Condition takes only the following keyword "

View File

@@ -21,17 +21,36 @@ class ConditionInterface(metaclass=ABCMeta):
""" """
Return the problem to which the condition is associated. Return the problem to which the condition is associated.
:return: Problem to which the condition is associated :return: Problem to which the condition is associated.
:rtype: pina.problem.AbstractProblem :rtype: pina.problem.AbstractProblem
""" """
return self._problem return self._problem
@problem.setter @problem.setter
def problem(self, value): def problem(self, value):
"""
Set the problem to which the condition is associated.
:param value: Problem to which the condition is associated.
:type value: pina.problem.AbstractProblem
"""
self._problem = value self._problem = value
@staticmethod @staticmethod
def _check_graph_list_consistency(data_list): def _check_graph_list_consistency(data_list):
"""
Check if the list of Data/Graph objects is consistent.
:param data_list: list of Data/Graph objects.
:type data_list: list(Data) | list(Graph)
:raises ValueError: Input data must be either Data or Graph objects.
:raises ValueError: All elements in the list must have the same keys.
:raises ValueError: Type mismatch in data tensors.
:raises ValueError: Label mismatch in LabelTensors.
"""
# If the data is a Graph or Data object, return (do not need to check # If the data is a Graph or Data object, return (do not need to check
# anything) # anything)

View File

@@ -23,17 +23,20 @@ class DataCondition(ConditionInterface):
def __new__(cls, input, conditional_variables=None): def __new__(cls, input, conditional_variables=None):
""" """
Instanciate the correct subclass of DataCondition by checking the type Instantiate the appropriate subclass of DataCondition based on the
of the input data (input and conditional_variables). types of input data.
:param input: Input data for the condition.
:type input: torch.Tensor | LabelTensor | Graph | Data
:param conditional_variables: Conditional variables for the condition.
:type conditional_variables: torch.Tensor | LabelTensor
:return: Subclass of DataCondition.
:rtype: TensorDataCondition | GraphDataCondition
:raises ValueError: If input is not of type :class:`torch.Tensor`,
:class:`LabelTensor`, :class:`Graph`, or :class:`Data`.
:param input: torch.Tensor or Graph/Data object containing the input
data
:type input: torch.Tensor or Graph or Data
:param conditional_variables: torch.Tensor or LabelTensor containing
the conditional variables
:type conditional_variables: torch.Tensor or LabelTensor
:return: DataCondition subclass
:rtype: TensorDataCondition or GraphDataCondition
""" """
if cls != DataCondition: if cls != DataCondition:
return super().__new__(cls) return super().__new__(cls)
@@ -56,12 +59,15 @@ class DataCondition(ConditionInterface):
Initialize the DataCondition, storing the input and conditional Initialize the DataCondition, storing the input and conditional
variables (if any). variables (if any).
:param input: torch.Tensor or Graph/Data object containing the input :param input: Input data for the condition.
data
:type input: torch.Tensor or Graph or Data :type input: torch.Tensor or Graph or Data
:param conditional_variables: torch.Tensor or LabelTensor containing :param conditional_variables: Conditional variables for the condition.
the conditional variables
:type conditional_variables: torch.Tensor or LabelTensor :type conditional_variables: torch.Tensor or LabelTensor
.. note::
If either `input` is composed by a list of :class:`Graph`/
:class:`Data` objects, all elements must have the same structure
(keys and data types)
""" """
super().__init__() super().__init__()
self.input = input self.input = input

View File

@@ -20,9 +20,9 @@ class DomainEquationCondition(ConditionInterface):
""" """
Initialize the DomainEquationCondition, storing the domain and equation. Initialize the DomainEquationCondition, storing the domain and equation.
:param DomainInterface domain: Domain object containing the domain data :param DomainInterface domain: Domain object containing the domain data.
:param EquationInterface equation: Equation object containing the :param EquationInterface equation: Equation object containing the
equation data equation data.
""" """
super().__init__() super().__init__()
self.domain = domain self.domain = domain

View File

@@ -22,15 +22,18 @@ class InputEquationCondition(ConditionInterface):
def __new__(cls, input, equation): def __new__(cls, input, equation):
""" """
Instanciate the correct subclass of InputEquationCondition by checking Instantiate the appropriate subclass of InputEquationCondition based on
the type of the input data (only `input`). the type of input data.
:param input: torch.Tensor or Graph/Data object containing the input :param input: Input data. It can be a LabelTensor or a Graph object.
:type input: torch.Tensor or Graph or Data :type input: LabelTensor | Graph
:param EquationInterface equation: Equation object containing the :param EquationInterface equation: Equation object containing the
equation function equation function.
:return: InputEquationCondition subclass :return: Subclass of InputEquationCondition, based on the input type.
:rtype: InputTensorEquationCondition or InputGraphEquationCondition :rtype: InputTensorEquationCondition | InputGraphEquationCondition
:raises ValueError: If input is not of type :class:`torch.Tensor`,
:class:`LabelTensor`, :class:`Graph`, or :class:`Data`.
""" """
# If the class is already a subclass, return the instance # If the class is already a subclass, return the instance
@@ -56,11 +59,18 @@ class InputEquationCondition(ConditionInterface):
""" """
Initialize the InputEquationCondition by storing the input and equation. Initialize the InputEquationCondition by storing the input and equation.
:param input: torch.Tensor or Graph/Data object containing the input :param input: torch.Tensor or Graph/Data object containing the input.
:type input: torch.Tensor or Graph or Data :type input: torch.Tensor | Graph
:param EquationInterface equation: Equation object containing the :param EquationInterface equation: Equation object containing the
equation function equation function.
.. note::
If ``input`` is composed by a list of :class:`Graph`/:class:`Data`
objects, all elements must have the same structure (keys and data
types). Moreover, at least one attribute must be a
:class:`LabelTensor`.
""" """
super().__init__() super().__init__()
self.input = input self.input = input
self.equation = equation self.equation = equation
@@ -90,11 +100,15 @@ class InputGraphEquationCondition(InputEquationCondition):
@staticmethod @staticmethod
def _check_label_tensor(input): def _check_label_tensor(input):
""" """
Check if the input is a LabelTensor. Check if at least one LabelTensor is present in the Graph object.
:param input: input data :param input: Input data.
:type input: torch.Tensor or Graph or Data :type input: torch.Tensor or Graph or Data
:raises ValueError: If the input data object does not contain at least
one LabelTensor.
""" """
# Store the fist element of the list/tuple if input is a list/tuple # Store the fist element of the list/tuple if input is a list/tuple
# it is anougth to check the first element because all elements must # it is anougth to check the first element because all elements must
# have the same type and structure (already checked) # have the same type and structure (already checked)

View File

@@ -21,17 +21,23 @@ class InputTargetCondition(ConditionInterface):
def __new__(cls, input, target): def __new__(cls, input, target):
""" """
Instanciate the correct subclass of InputTargetCondition by checking the Instantiate the appropriate subclass of InputTargetCondition based on
type of the input and target data. the types of input and target data.
:param input: torch.Tensor or Graph/Data object containing the input :param input: Input data for the condition.
:type input: torch.Tensor or Graph or Data :type input: torch.Tensor | Graph | Data | list | tuple
:param target: torch.Tensor or Graph/Data object containing the target :param target: Target data for the condition.
:type target: torch.Tensor or Graph or Data Graph, Data, or list/tuple.
:return: InputTargetCondition subclass :type target: torch.Tensor | Graph | Data | list | tuple
:rtype: TensorInputTensorTargetCondition or :return: Subclass of InputTargetCondition
TensorInputGraphTargetCondition or GraphInputTensorTargetCondition :rtype: TensorInputTensorTargetCondition |
or GraphInputGraphTargetCondition TensorInputGraphTargetCondition |
GraphInputTensorTargetCondition |
GraphInputGraphTargetCondition
:raises ValueError: If input and or target are not of type
:class:`torch.Tensor`, :class:`LabelTensor`, :class:`Graph`, or
:class:`Data`.
""" """
if cls != InputTargetCondition: if cls != InputTargetCondition:
return super().__new__(cls) return super().__new__(cls)
@@ -74,10 +80,16 @@ class InputTargetCondition(ConditionInterface):
Initialize the InputTargetCondition, storing the input and target data. Initialize the InputTargetCondition, storing the input and target data.
:param input: torch.Tensor or Graph/Data object containing the input :param input: torch.Tensor or Graph/Data object containing the input
:type input: torch.Tensor or Graph or Data :type input: torch.Tensor | Graph or Data
:param target: torch.Tensor or Graph/Data object containing the target :param target: torch.Tensor or Graph/Data object containing the target
:type target: torch.Tensor or Graph or Data :type target: torch.Tensor or Graph or Data
.. note::
If either ``input`` or ``target`` are composed by a list of
:class:`Graph`/:class:`Data` objects, all elements must have the
same structure (keys and data types)
""" """
super().__init__() super().__init__()
self._check_input_target_len(input, target) self._check_input_target_len(input, target)
self.input = input self.input = input
@@ -97,25 +109,27 @@ class InputTargetCondition(ConditionInterface):
class TensorInputTensorTargetCondition(InputTargetCondition): class TensorInputTensorTargetCondition(InputTargetCondition):
""" """
InputTargetCondition subclass for torch.Tensor input and target data. InputTargetCondition subclass for :class:`torch.Tensor`/:class:`LabelTensor`
input and target data.
""" """
class TensorInputGraphTargetCondition(InputTargetCondition): class TensorInputGraphTargetCondition(InputTargetCondition):
""" """
InputTargetCondition subclass for torch.Tensor input and Graph/Data target InputTargetCondition subclass for :class:`torch.Tensor`/:class:`LabelTensor`
data. input and :class:`Graph`/:class:`Data` target data.
""" """
class GraphInputTensorTargetCondition(InputTargetCondition): class GraphInputTensorTargetCondition(InputTargetCondition):
""" """
InputTargetCondition subclass for Graph/Data input and torch.Tensor target InputTargetCondition subclass for :class:`Graph`/:class:`Data` input and
data. :class:`torch.Tensor`/:class:`LabelTensor` target data.
""" """
class GraphInputGraphTargetCondition(InputTargetCondition): class GraphInputGraphTargetCondition(InputTargetCondition):
""" """
InputTargetCondition subclass for Graph/Data input and target data. InputTargetCondition subclass for :class:`Graph`/:class:`Data` input and
target data.
""" """