Doc conditions
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
Module for conditions.
|
||||
Module for importing Conditions objects.
|
||||
"""
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
"""Condition module."""
|
||||
"""
|
||||
Condition module.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from .data_condition import DataCondition
|
||||
@@ -13,12 +15,11 @@ warnings.filterwarnings("always", category=DeprecationWarning)
|
||||
|
||||
|
||||
def warning_function(new, old):
|
||||
"""Handle the deprecation warning.
|
||||
"""
|
||||
Handle the deprecation warning.
|
||||
|
||||
:param new: Object to use instead of the old one.
|
||||
:type new: str
|
||||
:param old: Object to deprecate.
|
||||
:type old: str
|
||||
:param str new: Object to use instead of the old one.
|
||||
:param str old: Object to deprecate.
|
||||
"""
|
||||
warnings.warn(
|
||||
f"'{old}' is deprecated and will be removed "
|
||||
@@ -72,7 +73,6 @@ class Condition:
|
||||
... input=data,
|
||||
... conditional_variables=conditional_variables
|
||||
... )
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = list(
|
||||
@@ -85,7 +85,19 @@ class Condition:
|
||||
)
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
"""
|
||||
Create a new condition object based on the keyword arguments passed.
|
||||
- ``input`` and ``target``: :class:`InputTargetCondition`
|
||||
- ``domain`` and ``equation``: :class:`DomainEquationCondition`
|
||||
- ``input`` and ``equation``: :class:`InputEquationCondition`
|
||||
- ``input``: :class:`DataCondition`
|
||||
- ``input`` and ``conditional_variables``: :class:`DataCondition`
|
||||
|
||||
:raises ValueError: No valid condition has been found.
|
||||
:return: A new condition instance belonging to the proper class.
|
||||
:rtype: ConditionInputTarget | ConditionInputEquation |
|
||||
ConditionDomainEquation | ConditionData
|
||||
"""
|
||||
if len(args) != 0:
|
||||
raise ValueError(
|
||||
"Condition takes only the following keyword "
|
||||
|
||||
@@ -21,17 +21,36 @@ class ConditionInterface(metaclass=ABCMeta):
|
||||
"""
|
||||
Return the problem to which the condition is associated.
|
||||
|
||||
:return: Problem to which the condition is associated
|
||||
:return: Problem to which the condition is associated.
|
||||
:rtype: pina.problem.AbstractProblem
|
||||
"""
|
||||
|
||||
return self._problem
|
||||
|
||||
@problem.setter
|
||||
def problem(self, value):
|
||||
"""
|
||||
Set the problem to which the condition is associated.
|
||||
|
||||
:param value: Problem to which the condition is associated.
|
||||
:type value: pina.problem.AbstractProblem
|
||||
"""
|
||||
|
||||
self._problem = value
|
||||
|
||||
@staticmethod
|
||||
def _check_graph_list_consistency(data_list):
|
||||
"""
|
||||
Check if the list of Data/Graph objects is consistent.
|
||||
|
||||
:param data_list: list of Data/Graph objects.
|
||||
:type data_list: list(Data) | list(Graph)
|
||||
|
||||
:raises ValueError: Input data must be either Data or Graph objects.
|
||||
:raises ValueError: All elements in the list must have the same keys.
|
||||
:raises ValueError: Type mismatch in data tensors.
|
||||
:raises ValueError: Label mismatch in LabelTensors.
|
||||
"""
|
||||
|
||||
# If the data is a Graph or Data object, return (do not need to check
|
||||
# anything)
|
||||
|
||||
@@ -23,17 +23,20 @@ class DataCondition(ConditionInterface):
|
||||
|
||||
def __new__(cls, input, conditional_variables=None):
|
||||
"""
|
||||
Instanciate the correct subclass of DataCondition by checking the type
|
||||
of the input data (input and conditional_variables).
|
||||
Instantiate the appropriate subclass of DataCondition based on the
|
||||
types of input data.
|
||||
|
||||
:param input: Input data for the condition.
|
||||
:type input: torch.Tensor | LabelTensor | Graph | Data
|
||||
:param conditional_variables: Conditional variables for the condition.
|
||||
:type conditional_variables: torch.Tensor | LabelTensor
|
||||
:return: Subclass of DataCondition.
|
||||
:rtype: TensorDataCondition | GraphDataCondition
|
||||
|
||||
:raises ValueError: If input is not of type :class:`torch.Tensor`,
|
||||
:class:`LabelTensor`, :class:`Graph`, or :class:`Data`.
|
||||
|
||||
|
||||
:param input: torch.Tensor or Graph/Data object containing the input
|
||||
data
|
||||
:type input: torch.Tensor or Graph or Data
|
||||
:param conditional_variables: torch.Tensor or LabelTensor containing
|
||||
the conditional variables
|
||||
:type conditional_variables: torch.Tensor or LabelTensor
|
||||
:return: DataCondition subclass
|
||||
:rtype: TensorDataCondition or GraphDataCondition
|
||||
"""
|
||||
if cls != DataCondition:
|
||||
return super().__new__(cls)
|
||||
@@ -56,12 +59,15 @@ class DataCondition(ConditionInterface):
|
||||
Initialize the DataCondition, storing the input and conditional
|
||||
variables (if any).
|
||||
|
||||
:param input: torch.Tensor or Graph/Data object containing the input
|
||||
data
|
||||
:param input: Input data for the condition.
|
||||
:type input: torch.Tensor or Graph or Data
|
||||
:param conditional_variables: torch.Tensor or LabelTensor containing
|
||||
the conditional variables
|
||||
:param conditional_variables: Conditional variables for the condition.
|
||||
:type conditional_variables: torch.Tensor or LabelTensor
|
||||
|
||||
.. note::
|
||||
If either `input` is composed by a list of :class:`Graph`/
|
||||
:class:`Data` objects, all elements must have the same structure
|
||||
(keys and data types)
|
||||
"""
|
||||
super().__init__()
|
||||
self.input = input
|
||||
|
||||
@@ -20,9 +20,9 @@ class DomainEquationCondition(ConditionInterface):
|
||||
"""
|
||||
Initialize the DomainEquationCondition, storing the domain and equation.
|
||||
|
||||
:param DomainInterface domain: Domain object containing the domain data
|
||||
:param DomainInterface domain: Domain object containing the domain data.
|
||||
:param EquationInterface equation: Equation object containing the
|
||||
equation data
|
||||
equation data.
|
||||
"""
|
||||
super().__init__()
|
||||
self.domain = domain
|
||||
|
||||
@@ -22,15 +22,18 @@ class InputEquationCondition(ConditionInterface):
|
||||
|
||||
def __new__(cls, input, equation):
|
||||
"""
|
||||
Instanciate the correct subclass of InputEquationCondition by checking
|
||||
the type of the input data (only `input`).
|
||||
Instantiate the appropriate subclass of InputEquationCondition based on
|
||||
the type of input data.
|
||||
|
||||
:param input: torch.Tensor or Graph/Data object containing the input
|
||||
:type input: torch.Tensor or Graph or Data
|
||||
:param input: Input data. It can be a LabelTensor or a Graph object.
|
||||
:type input: LabelTensor | Graph
|
||||
:param EquationInterface equation: Equation object containing the
|
||||
equation function
|
||||
:return: InputEquationCondition subclass
|
||||
:rtype: InputTensorEquationCondition or InputGraphEquationCondition
|
||||
equation function.
|
||||
:return: Subclass of InputEquationCondition, based on the input type.
|
||||
:rtype: InputTensorEquationCondition | InputGraphEquationCondition
|
||||
|
||||
:raises ValueError: If input is not of type :class:`torch.Tensor`,
|
||||
:class:`LabelTensor`, :class:`Graph`, or :class:`Data`.
|
||||
"""
|
||||
|
||||
# If the class is already a subclass, return the instance
|
||||
@@ -56,11 +59,18 @@ class InputEquationCondition(ConditionInterface):
|
||||
"""
|
||||
Initialize the InputEquationCondition by storing the input and equation.
|
||||
|
||||
:param input: torch.Tensor or Graph/Data object containing the input
|
||||
:type input: torch.Tensor or Graph or Data
|
||||
:param input: torch.Tensor or Graph/Data object containing the input.
|
||||
:type input: torch.Tensor | Graph
|
||||
:param EquationInterface equation: Equation object containing the
|
||||
equation function
|
||||
equation function.
|
||||
|
||||
.. note::
|
||||
If ``input`` is composed by a list of :class:`Graph`/:class:`Data`
|
||||
objects, all elements must have the same structure (keys and data
|
||||
types). Moreover, at least one attribute must be a
|
||||
:class:`LabelTensor`.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.input = input
|
||||
self.equation = equation
|
||||
@@ -90,11 +100,15 @@ class InputGraphEquationCondition(InputEquationCondition):
|
||||
@staticmethod
|
||||
def _check_label_tensor(input):
|
||||
"""
|
||||
Check if the input is a LabelTensor.
|
||||
Check if at least one LabelTensor is present in the Graph object.
|
||||
|
||||
:param input: input data
|
||||
:param input: Input data.
|
||||
:type input: torch.Tensor or Graph or Data
|
||||
|
||||
:raises ValueError: If the input data object does not contain at least
|
||||
one LabelTensor.
|
||||
"""
|
||||
|
||||
# Store the fist element of the list/tuple if input is a list/tuple
|
||||
# it is anougth to check the first element because all elements must
|
||||
# have the same type and structure (already checked)
|
||||
|
||||
@@ -21,17 +21,23 @@ class InputTargetCondition(ConditionInterface):
|
||||
|
||||
def __new__(cls, input, target):
|
||||
"""
|
||||
Instanciate the correct subclass of InputTargetCondition by checking the
|
||||
type of the input and target data.
|
||||
Instantiate the appropriate subclass of InputTargetCondition based on
|
||||
the types of input and target data.
|
||||
|
||||
:param input: torch.Tensor or Graph/Data object containing the input
|
||||
:type input: torch.Tensor or Graph or Data
|
||||
:param target: torch.Tensor or Graph/Data object containing the target
|
||||
:type target: torch.Tensor or Graph or Data
|
||||
:return: InputTargetCondition subclass
|
||||
:rtype: TensorInputTensorTargetCondition or
|
||||
TensorInputGraphTargetCondition or GraphInputTensorTargetCondition
|
||||
or GraphInputGraphTargetCondition
|
||||
:param input: Input data for the condition.
|
||||
:type input: torch.Tensor | Graph | Data | list | tuple
|
||||
:param target: Target data for the condition.
|
||||
Graph, Data, or list/tuple.
|
||||
:type target: torch.Tensor | Graph | Data | list | tuple
|
||||
:return: Subclass of InputTargetCondition
|
||||
:rtype: TensorInputTensorTargetCondition |
|
||||
TensorInputGraphTargetCondition |
|
||||
GraphInputTensorTargetCondition |
|
||||
GraphInputGraphTargetCondition
|
||||
|
||||
:raises ValueError: If input and or target are not of type
|
||||
:class:`torch.Tensor`, :class:`LabelTensor`, :class:`Graph`, or
|
||||
:class:`Data`.
|
||||
"""
|
||||
if cls != InputTargetCondition:
|
||||
return super().__new__(cls)
|
||||
@@ -74,10 +80,16 @@ class InputTargetCondition(ConditionInterface):
|
||||
Initialize the InputTargetCondition, storing the input and target data.
|
||||
|
||||
:param input: torch.Tensor or Graph/Data object containing the input
|
||||
:type input: torch.Tensor or Graph or Data
|
||||
:type input: torch.Tensor | Graph or Data
|
||||
:param target: torch.Tensor or Graph/Data object containing the target
|
||||
:type target: torch.Tensor or Graph or Data
|
||||
|
||||
.. note::
|
||||
If either ``input`` or ``target`` are composed by a list of
|
||||
:class:`Graph`/:class:`Data` objects, all elements must have the
|
||||
same structure (keys and data types)
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self._check_input_target_len(input, target)
|
||||
self.input = input
|
||||
@@ -97,25 +109,27 @@ class InputTargetCondition(ConditionInterface):
|
||||
|
||||
class TensorInputTensorTargetCondition(InputTargetCondition):
|
||||
"""
|
||||
InputTargetCondition subclass for torch.Tensor input and target data.
|
||||
InputTargetCondition subclass for :class:`torch.Tensor`/:class:`LabelTensor`
|
||||
input and target data.
|
||||
"""
|
||||
|
||||
|
||||
class TensorInputGraphTargetCondition(InputTargetCondition):
|
||||
"""
|
||||
InputTargetCondition subclass for torch.Tensor input and Graph/Data target
|
||||
data.
|
||||
InputTargetCondition subclass for :class:`torch.Tensor`/:class:`LabelTensor`
|
||||
input and :class:`Graph`/:class:`Data` target data.
|
||||
"""
|
||||
|
||||
|
||||
class GraphInputTensorTargetCondition(InputTargetCondition):
|
||||
"""
|
||||
InputTargetCondition subclass for Graph/Data input and torch.Tensor target
|
||||
data.
|
||||
InputTargetCondition subclass for :class:`Graph`/:class:`Data` input and
|
||||
:class:`torch.Tensor`/:class:`LabelTensor` target data.
|
||||
"""
|
||||
|
||||
|
||||
class GraphInputGraphTargetCondition(InputTargetCondition):
|
||||
"""
|
||||
InputTargetCondition subclass for Graph/Data input and target data.
|
||||
InputTargetCondition subclass for :class:`Graph`/:class:`Data` input and
|
||||
target data.
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user