Doc conditions

This commit is contained in:
FilippoOlivo
2025-03-11 15:22:35 +01:00
parent 92bb04fafe
commit be0e39a050
7 changed files with 119 additions and 54 deletions

View File

@@ -1,5 +1,5 @@
"""
Module for conditions.
Module for importing Conditions objects.
"""
__all__ = [

View File

@@ -1,4 +1,6 @@
"""Condition module."""
"""
Condition module.
"""
import warnings
from .data_condition import DataCondition
@@ -13,12 +15,11 @@ warnings.filterwarnings("always", category=DeprecationWarning)
def warning_function(new, old):
"""Handle the deprecation warning.
"""
Handle the deprecation warning.
:param new: Object to use instead of the old one.
:type new: str
:param old: Object to deprecate.
:type old: str
:param str new: Object to use instead of the old one.
:param str old: Object to deprecate.
"""
warnings.warn(
f"'{old}' is deprecated and will be removed "
@@ -72,7 +73,6 @@ class Condition:
... input=data,
... conditional_variables=conditional_variables
... )
"""
__slots__ = list(
@@ -85,7 +85,19 @@ class Condition:
)
def __new__(cls, *args, **kwargs):
"""
Create a new condition object based on the keyword arguments passed.
- ``input`` and ``target``: :class:`InputTargetCondition`
- ``domain`` and ``equation``: :class:`DomainEquationCondition`
- ``input`` and ``equation``: :class:`InputEquationCondition`
- ``input``: :class:`DataCondition`
- ``input`` and ``conditional_variables``: :class:`DataCondition`
:raises ValueError: No valid condition has been found.
:return: A new condition instance belonging to the proper class.
:rtype: ConditionInputTarget | ConditionInputEquation |
ConditionDomainEquation | ConditionData
"""
if len(args) != 0:
raise ValueError(
"Condition takes only the following keyword "

View File

@@ -21,17 +21,36 @@ class ConditionInterface(metaclass=ABCMeta):
"""
Return the problem to which the condition is associated.
:return: Problem to which the condition is associated
:return: Problem to which the condition is associated.
:rtype: pina.problem.AbstractProblem
"""
return self._problem
@problem.setter
def problem(self, value):
"""
Set the problem to which the condition is associated.
:param value: Problem to which the condition is associated.
:type value: pina.problem.AbstractProblem
"""
self._problem = value
@staticmethod
def _check_graph_list_consistency(data_list):
"""
Check if the list of Data/Graph objects is consistent.
:param data_list: list of Data/Graph objects.
:type data_list: list(Data) | list(Graph)
:raises ValueError: Input data must be either Data or Graph objects.
:raises ValueError: All elements in the list must have the same keys.
:raises ValueError: Type mismatch in data tensors.
:raises ValueError: Label mismatch in LabelTensors.
"""
# If the data is a Graph or Data object, return (do not need to check
# anything)

View File

@@ -23,17 +23,20 @@ class DataCondition(ConditionInterface):
def __new__(cls, input, conditional_variables=None):
"""
Instanciate the correct subclass of DataCondition by checking the type
of the input data (input and conditional_variables).
Instantiate the appropriate subclass of DataCondition based on the
types of input data.
:param input: Input data for the condition.
:type input: torch.Tensor | LabelTensor | Graph | Data
:param conditional_variables: Conditional variables for the condition.
:type conditional_variables: torch.Tensor | LabelTensor
:return: Subclass of DataCondition.
:rtype: TensorDataCondition | GraphDataCondition
:raises ValueError: If input is not of type :class:`torch.Tensor`,
:class:`LabelTensor`, :class:`Graph`, or :class:`Data`.
:param input: torch.Tensor or Graph/Data object containing the input
data
:type input: torch.Tensor or Graph or Data
:param conditional_variables: torch.Tensor or LabelTensor containing
the conditional variables
:type conditional_variables: torch.Tensor or LabelTensor
:return: DataCondition subclass
:rtype: TensorDataCondition or GraphDataCondition
"""
if cls != DataCondition:
return super().__new__(cls)
@@ -56,12 +59,15 @@ class DataCondition(ConditionInterface):
Initialize the DataCondition, storing the input and conditional
variables (if any).
:param input: torch.Tensor or Graph/Data object containing the input
data
:param input: Input data for the condition.
:type input: torch.Tensor or Graph or Data
:param conditional_variables: torch.Tensor or LabelTensor containing
the conditional variables
:param conditional_variables: Conditional variables for the condition.
:type conditional_variables: torch.Tensor or LabelTensor
.. note::
If either `input` is composed by a list of :class:`Graph`/
:class:`Data` objects, all elements must have the same structure
(keys and data types)
"""
super().__init__()
self.input = input

View File

@@ -20,9 +20,9 @@ class DomainEquationCondition(ConditionInterface):
"""
Initialize the DomainEquationCondition, storing the domain and equation.
:param DomainInterface domain: Domain object containing the domain data
:param DomainInterface domain: Domain object containing the domain data.
:param EquationInterface equation: Equation object containing the
equation data
equation data.
"""
super().__init__()
self.domain = domain

View File

@@ -22,15 +22,18 @@ class InputEquationCondition(ConditionInterface):
def __new__(cls, input, equation):
"""
Instanciate the correct subclass of InputEquationCondition by checking
the type of the input data (only `input`).
Instantiate the appropriate subclass of InputEquationCondition based on
the type of input data.
:param input: torch.Tensor or Graph/Data object containing the input
:type input: torch.Tensor or Graph or Data
:param input: Input data. It can be a LabelTensor or a Graph object.
:type input: LabelTensor | Graph
:param EquationInterface equation: Equation object containing the
equation function
:return: InputEquationCondition subclass
:rtype: InputTensorEquationCondition or InputGraphEquationCondition
equation function.
:return: Subclass of InputEquationCondition, based on the input type.
:rtype: InputTensorEquationCondition | InputGraphEquationCondition
:raises ValueError: If input is not of type :class:`torch.Tensor`,
:class:`LabelTensor`, :class:`Graph`, or :class:`Data`.
"""
# If the class is already a subclass, return the instance
@@ -56,11 +59,18 @@ class InputEquationCondition(ConditionInterface):
"""
Initialize the InputEquationCondition by storing the input and equation.
:param input: torch.Tensor or Graph/Data object containing the input
:type input: torch.Tensor or Graph or Data
:param input: torch.Tensor or Graph/Data object containing the input.
:type input: torch.Tensor | Graph
:param EquationInterface equation: Equation object containing the
equation function
equation function.
.. note::
If ``input`` is composed by a list of :class:`Graph`/:class:`Data`
objects, all elements must have the same structure (keys and data
types). Moreover, at least one attribute must be a
:class:`LabelTensor`.
"""
super().__init__()
self.input = input
self.equation = equation
@@ -90,11 +100,15 @@ class InputGraphEquationCondition(InputEquationCondition):
@staticmethod
def _check_label_tensor(input):
"""
Check if the input is a LabelTensor.
Check if at least one LabelTensor is present in the Graph object.
:param input: input data
:param input: Input data.
:type input: torch.Tensor or Graph or Data
:raises ValueError: If the input data object does not contain at least
one LabelTensor.
"""
# Store the fist element of the list/tuple if input is a list/tuple
# it is anougth to check the first element because all elements must
# have the same type and structure (already checked)

View File

@@ -21,17 +21,23 @@ class InputTargetCondition(ConditionInterface):
def __new__(cls, input, target):
"""
Instanciate the correct subclass of InputTargetCondition by checking the
type of the input and target data.
Instantiate the appropriate subclass of InputTargetCondition based on
the types of input and target data.
:param input: torch.Tensor or Graph/Data object containing the input
:type input: torch.Tensor or Graph or Data
:param target: torch.Tensor or Graph/Data object containing the target
:type target: torch.Tensor or Graph or Data
:return: InputTargetCondition subclass
:rtype: TensorInputTensorTargetCondition or
TensorInputGraphTargetCondition or GraphInputTensorTargetCondition
or GraphInputGraphTargetCondition
:param input: Input data for the condition.
:type input: torch.Tensor | Graph | Data | list | tuple
:param target: Target data for the condition.
Graph, Data, or list/tuple.
:type target: torch.Tensor | Graph | Data | list | tuple
:return: Subclass of InputTargetCondition
:rtype: TensorInputTensorTargetCondition |
TensorInputGraphTargetCondition |
GraphInputTensorTargetCondition |
GraphInputGraphTargetCondition
:raises ValueError: If input and or target are not of type
:class:`torch.Tensor`, :class:`LabelTensor`, :class:`Graph`, or
:class:`Data`.
"""
if cls != InputTargetCondition:
return super().__new__(cls)
@@ -74,10 +80,16 @@ class InputTargetCondition(ConditionInterface):
Initialize the InputTargetCondition, storing the input and target data.
:param input: torch.Tensor or Graph/Data object containing the input
:type input: torch.Tensor or Graph or Data
:type input: torch.Tensor | Graph or Data
:param target: torch.Tensor or Graph/Data object containing the target
:type target: torch.Tensor or Graph or Data
.. note::
If either ``input`` or ``target`` are composed by a list of
:class:`Graph`/:class:`Data` objects, all elements must have the
same structure (keys and data types)
"""
super().__init__()
self._check_input_target_len(input, target)
self.input = input
@@ -97,25 +109,27 @@ class InputTargetCondition(ConditionInterface):
class TensorInputTensorTargetCondition(InputTargetCondition):
"""
InputTargetCondition subclass for torch.Tensor input and target data.
InputTargetCondition subclass for :class:`torch.Tensor`/:class:`LabelTensor`
input and target data.
"""
class TensorInputGraphTargetCondition(InputTargetCondition):
"""
InputTargetCondition subclass for torch.Tensor input and Graph/Data target
data.
InputTargetCondition subclass for :class:`torch.Tensor`/:class:`LabelTensor`
input and :class:`Graph`/:class:`Data` target data.
"""
class GraphInputTensorTargetCondition(InputTargetCondition):
"""
InputTargetCondition subclass for Graph/Data input and torch.Tensor target
data.
InputTargetCondition subclass for :class:`Graph`/:class:`Data` input and
:class:`torch.Tensor`/:class:`LabelTensor` target data.
"""
class GraphInputGraphTargetCondition(InputTargetCondition):
"""
InputTargetCondition subclass for Graph/Data input and target data.
InputTargetCondition subclass for :class:`Graph`/:class:`Data` input and
target data.
"""