Doc conditions
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Module for conditions.
|
Module for importing Conditions objects.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
"""Condition module."""
|
"""
|
||||||
|
Condition module.
|
||||||
|
"""
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from .data_condition import DataCondition
|
from .data_condition import DataCondition
|
||||||
@@ -13,12 +15,11 @@ warnings.filterwarnings("always", category=DeprecationWarning)
|
|||||||
|
|
||||||
|
|
||||||
def warning_function(new, old):
|
def warning_function(new, old):
|
||||||
"""Handle the deprecation warning.
|
"""
|
||||||
|
Handle the deprecation warning.
|
||||||
|
|
||||||
:param new: Object to use instead of the old one.
|
:param str new: Object to use instead of the old one.
|
||||||
:type new: str
|
:param str old: Object to deprecate.
|
||||||
:param old: Object to deprecate.
|
|
||||||
:type old: str
|
|
||||||
"""
|
"""
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
f"'{old}' is deprecated and will be removed "
|
f"'{old}' is deprecated and will be removed "
|
||||||
@@ -72,7 +73,6 @@ class Condition:
|
|||||||
... input=data,
|
... input=data,
|
||||||
... conditional_variables=conditional_variables
|
... conditional_variables=conditional_variables
|
||||||
... )
|
... )
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__slots__ = list(
|
__slots__ = list(
|
||||||
@@ -85,7 +85,19 @@ class Condition:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
def __new__(cls, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Create a new condition object based on the keyword arguments passed.
|
||||||
|
- ``input`` and ``target``: :class:`InputTargetCondition`
|
||||||
|
- ``domain`` and ``equation``: :class:`DomainEquationCondition`
|
||||||
|
- ``input`` and ``equation``: :class:`InputEquationCondition`
|
||||||
|
- ``input``: :class:`DataCondition`
|
||||||
|
- ``input`` and ``conditional_variables``: :class:`DataCondition`
|
||||||
|
|
||||||
|
:raises ValueError: No valid condition has been found.
|
||||||
|
:return: A new condition instance belonging to the proper class.
|
||||||
|
:rtype: ConditionInputTarget | ConditionInputEquation |
|
||||||
|
ConditionDomainEquation | ConditionData
|
||||||
|
"""
|
||||||
if len(args) != 0:
|
if len(args) != 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Condition takes only the following keyword "
|
"Condition takes only the following keyword "
|
||||||
|
|||||||
@@ -21,17 +21,36 @@ class ConditionInterface(metaclass=ABCMeta):
|
|||||||
"""
|
"""
|
||||||
Return the problem to which the condition is associated.
|
Return the problem to which the condition is associated.
|
||||||
|
|
||||||
:return: Problem to which the condition is associated
|
:return: Problem to which the condition is associated.
|
||||||
:rtype: pina.problem.AbstractProblem
|
:rtype: pina.problem.AbstractProblem
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return self._problem
|
return self._problem
|
||||||
|
|
||||||
@problem.setter
|
@problem.setter
|
||||||
def problem(self, value):
|
def problem(self, value):
|
||||||
|
"""
|
||||||
|
Set the problem to which the condition is associated.
|
||||||
|
|
||||||
|
:param value: Problem to which the condition is associated.
|
||||||
|
:type value: pina.problem.AbstractProblem
|
||||||
|
"""
|
||||||
|
|
||||||
self._problem = value
|
self._problem = value
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _check_graph_list_consistency(data_list):
|
def _check_graph_list_consistency(data_list):
|
||||||
|
"""
|
||||||
|
Check if the list of Data/Graph objects is consistent.
|
||||||
|
|
||||||
|
:param data_list: list of Data/Graph objects.
|
||||||
|
:type data_list: list(Data) | list(Graph)
|
||||||
|
|
||||||
|
:raises ValueError: Input data must be either Data or Graph objects.
|
||||||
|
:raises ValueError: All elements in the list must have the same keys.
|
||||||
|
:raises ValueError: Type mismatch in data tensors.
|
||||||
|
:raises ValueError: Label mismatch in LabelTensors.
|
||||||
|
"""
|
||||||
|
|
||||||
# If the data is a Graph or Data object, return (do not need to check
|
# If the data is a Graph or Data object, return (do not need to check
|
||||||
# anything)
|
# anything)
|
||||||
|
|||||||
@@ -23,17 +23,20 @@ class DataCondition(ConditionInterface):
|
|||||||
|
|
||||||
def __new__(cls, input, conditional_variables=None):
|
def __new__(cls, input, conditional_variables=None):
|
||||||
"""
|
"""
|
||||||
Instanciate the correct subclass of DataCondition by checking the type
|
Instantiate the appropriate subclass of DataCondition based on the
|
||||||
of the input data (input and conditional_variables).
|
types of input data.
|
||||||
|
|
||||||
|
:param input: Input data for the condition.
|
||||||
|
:type input: torch.Tensor | LabelTensor | Graph | Data
|
||||||
|
:param conditional_variables: Conditional variables for the condition.
|
||||||
|
:type conditional_variables: torch.Tensor | LabelTensor
|
||||||
|
:return: Subclass of DataCondition.
|
||||||
|
:rtype: TensorDataCondition | GraphDataCondition
|
||||||
|
|
||||||
|
:raises ValueError: If input is not of type :class:`torch.Tensor`,
|
||||||
|
:class:`LabelTensor`, :class:`Graph`, or :class:`Data`.
|
||||||
|
|
||||||
|
|
||||||
:param input: torch.Tensor or Graph/Data object containing the input
|
|
||||||
data
|
|
||||||
:type input: torch.Tensor or Graph or Data
|
|
||||||
:param conditional_variables: torch.Tensor or LabelTensor containing
|
|
||||||
the conditional variables
|
|
||||||
:type conditional_variables: torch.Tensor or LabelTensor
|
|
||||||
:return: DataCondition subclass
|
|
||||||
:rtype: TensorDataCondition or GraphDataCondition
|
|
||||||
"""
|
"""
|
||||||
if cls != DataCondition:
|
if cls != DataCondition:
|
||||||
return super().__new__(cls)
|
return super().__new__(cls)
|
||||||
@@ -56,12 +59,15 @@ class DataCondition(ConditionInterface):
|
|||||||
Initialize the DataCondition, storing the input and conditional
|
Initialize the DataCondition, storing the input and conditional
|
||||||
variables (if any).
|
variables (if any).
|
||||||
|
|
||||||
:param input: torch.Tensor or Graph/Data object containing the input
|
:param input: Input data for the condition.
|
||||||
data
|
|
||||||
:type input: torch.Tensor or Graph or Data
|
:type input: torch.Tensor or Graph or Data
|
||||||
:param conditional_variables: torch.Tensor or LabelTensor containing
|
:param conditional_variables: Conditional variables for the condition.
|
||||||
the conditional variables
|
|
||||||
:type conditional_variables: torch.Tensor or LabelTensor
|
:type conditional_variables: torch.Tensor or LabelTensor
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
If either `input` is composed by a list of :class:`Graph`/
|
||||||
|
:class:`Data` objects, all elements must have the same structure
|
||||||
|
(keys and data types)
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.input = input
|
self.input = input
|
||||||
|
|||||||
@@ -20,9 +20,9 @@ class DomainEquationCondition(ConditionInterface):
|
|||||||
"""
|
"""
|
||||||
Initialize the DomainEquationCondition, storing the domain and equation.
|
Initialize the DomainEquationCondition, storing the domain and equation.
|
||||||
|
|
||||||
:param DomainInterface domain: Domain object containing the domain data
|
:param DomainInterface domain: Domain object containing the domain data.
|
||||||
:param EquationInterface equation: Equation object containing the
|
:param EquationInterface equation: Equation object containing the
|
||||||
equation data
|
equation data.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.domain = domain
|
self.domain = domain
|
||||||
|
|||||||
@@ -22,15 +22,18 @@ class InputEquationCondition(ConditionInterface):
|
|||||||
|
|
||||||
def __new__(cls, input, equation):
|
def __new__(cls, input, equation):
|
||||||
"""
|
"""
|
||||||
Instanciate the correct subclass of InputEquationCondition by checking
|
Instantiate the appropriate subclass of InputEquationCondition based on
|
||||||
the type of the input data (only `input`).
|
the type of input data.
|
||||||
|
|
||||||
:param input: torch.Tensor or Graph/Data object containing the input
|
:param input: Input data. It can be a LabelTensor or a Graph object.
|
||||||
:type input: torch.Tensor or Graph or Data
|
:type input: LabelTensor | Graph
|
||||||
:param EquationInterface equation: Equation object containing the
|
:param EquationInterface equation: Equation object containing the
|
||||||
equation function
|
equation function.
|
||||||
:return: InputEquationCondition subclass
|
:return: Subclass of InputEquationCondition, based on the input type.
|
||||||
:rtype: InputTensorEquationCondition or InputGraphEquationCondition
|
:rtype: InputTensorEquationCondition | InputGraphEquationCondition
|
||||||
|
|
||||||
|
:raises ValueError: If input is not of type :class:`torch.Tensor`,
|
||||||
|
:class:`LabelTensor`, :class:`Graph`, or :class:`Data`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# If the class is already a subclass, return the instance
|
# If the class is already a subclass, return the instance
|
||||||
@@ -56,11 +59,18 @@ class InputEquationCondition(ConditionInterface):
|
|||||||
"""
|
"""
|
||||||
Initialize the InputEquationCondition by storing the input and equation.
|
Initialize the InputEquationCondition by storing the input and equation.
|
||||||
|
|
||||||
:param input: torch.Tensor or Graph/Data object containing the input
|
:param input: torch.Tensor or Graph/Data object containing the input.
|
||||||
:type input: torch.Tensor or Graph or Data
|
:type input: torch.Tensor | Graph
|
||||||
:param EquationInterface equation: Equation object containing the
|
:param EquationInterface equation: Equation object containing the
|
||||||
equation function
|
equation function.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
If ``input`` is composed by a list of :class:`Graph`/:class:`Data`
|
||||||
|
objects, all elements must have the same structure (keys and data
|
||||||
|
types). Moreover, at least one attribute must be a
|
||||||
|
:class:`LabelTensor`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.input = input
|
self.input = input
|
||||||
self.equation = equation
|
self.equation = equation
|
||||||
@@ -90,11 +100,15 @@ class InputGraphEquationCondition(InputEquationCondition):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _check_label_tensor(input):
|
def _check_label_tensor(input):
|
||||||
"""
|
"""
|
||||||
Check if the input is a LabelTensor.
|
Check if at least one LabelTensor is present in the Graph object.
|
||||||
|
|
||||||
:param input: input data
|
:param input: Input data.
|
||||||
:type input: torch.Tensor or Graph or Data
|
:type input: torch.Tensor or Graph or Data
|
||||||
|
|
||||||
|
:raises ValueError: If the input data object does not contain at least
|
||||||
|
one LabelTensor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Store the fist element of the list/tuple if input is a list/tuple
|
# Store the fist element of the list/tuple if input is a list/tuple
|
||||||
# it is anougth to check the first element because all elements must
|
# it is anougth to check the first element because all elements must
|
||||||
# have the same type and structure (already checked)
|
# have the same type and structure (already checked)
|
||||||
|
|||||||
@@ -21,17 +21,23 @@ class InputTargetCondition(ConditionInterface):
|
|||||||
|
|
||||||
def __new__(cls, input, target):
|
def __new__(cls, input, target):
|
||||||
"""
|
"""
|
||||||
Instanciate the correct subclass of InputTargetCondition by checking the
|
Instantiate the appropriate subclass of InputTargetCondition based on
|
||||||
type of the input and target data.
|
the types of input and target data.
|
||||||
|
|
||||||
:param input: torch.Tensor or Graph/Data object containing the input
|
:param input: Input data for the condition.
|
||||||
:type input: torch.Tensor or Graph or Data
|
:type input: torch.Tensor | Graph | Data | list | tuple
|
||||||
:param target: torch.Tensor or Graph/Data object containing the target
|
:param target: Target data for the condition.
|
||||||
:type target: torch.Tensor or Graph or Data
|
Graph, Data, or list/tuple.
|
||||||
:return: InputTargetCondition subclass
|
:type target: torch.Tensor | Graph | Data | list | tuple
|
||||||
:rtype: TensorInputTensorTargetCondition or
|
:return: Subclass of InputTargetCondition
|
||||||
TensorInputGraphTargetCondition or GraphInputTensorTargetCondition
|
:rtype: TensorInputTensorTargetCondition |
|
||||||
or GraphInputGraphTargetCondition
|
TensorInputGraphTargetCondition |
|
||||||
|
GraphInputTensorTargetCondition |
|
||||||
|
GraphInputGraphTargetCondition
|
||||||
|
|
||||||
|
:raises ValueError: If input and or target are not of type
|
||||||
|
:class:`torch.Tensor`, :class:`LabelTensor`, :class:`Graph`, or
|
||||||
|
:class:`Data`.
|
||||||
"""
|
"""
|
||||||
if cls != InputTargetCondition:
|
if cls != InputTargetCondition:
|
||||||
return super().__new__(cls)
|
return super().__new__(cls)
|
||||||
@@ -74,10 +80,16 @@ class InputTargetCondition(ConditionInterface):
|
|||||||
Initialize the InputTargetCondition, storing the input and target data.
|
Initialize the InputTargetCondition, storing the input and target data.
|
||||||
|
|
||||||
:param input: torch.Tensor or Graph/Data object containing the input
|
:param input: torch.Tensor or Graph/Data object containing the input
|
||||||
:type input: torch.Tensor or Graph or Data
|
:type input: torch.Tensor | Graph or Data
|
||||||
:param target: torch.Tensor or Graph/Data object containing the target
|
:param target: torch.Tensor or Graph/Data object containing the target
|
||||||
:type target: torch.Tensor or Graph or Data
|
:type target: torch.Tensor or Graph or Data
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
If either ``input`` or ``target`` are composed by a list of
|
||||||
|
:class:`Graph`/:class:`Data` objects, all elements must have the
|
||||||
|
same structure (keys and data types)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._check_input_target_len(input, target)
|
self._check_input_target_len(input, target)
|
||||||
self.input = input
|
self.input = input
|
||||||
@@ -97,25 +109,27 @@ class InputTargetCondition(ConditionInterface):
|
|||||||
|
|
||||||
class TensorInputTensorTargetCondition(InputTargetCondition):
|
class TensorInputTensorTargetCondition(InputTargetCondition):
|
||||||
"""
|
"""
|
||||||
InputTargetCondition subclass for torch.Tensor input and target data.
|
InputTargetCondition subclass for :class:`torch.Tensor`/:class:`LabelTensor`
|
||||||
|
input and target data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class TensorInputGraphTargetCondition(InputTargetCondition):
|
class TensorInputGraphTargetCondition(InputTargetCondition):
|
||||||
"""
|
"""
|
||||||
InputTargetCondition subclass for torch.Tensor input and Graph/Data target
|
InputTargetCondition subclass for :class:`torch.Tensor`/:class:`LabelTensor`
|
||||||
data.
|
input and :class:`Graph`/:class:`Data` target data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class GraphInputTensorTargetCondition(InputTargetCondition):
|
class GraphInputTensorTargetCondition(InputTargetCondition):
|
||||||
"""
|
"""
|
||||||
InputTargetCondition subclass for Graph/Data input and torch.Tensor target
|
InputTargetCondition subclass for :class:`Graph`/:class:`Data` input and
|
||||||
data.
|
:class:`torch.Tensor`/:class:`LabelTensor` target data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class GraphInputGraphTargetCondition(InputTargetCondition):
|
class GraphInputGraphTargetCondition(InputTargetCondition):
|
||||||
"""
|
"""
|
||||||
InputTargetCondition subclass for Graph/Data input and target data.
|
InputTargetCondition subclass for :class:`Graph`/:class:`Data` input and
|
||||||
|
target data.
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user