This commit is contained in:
FilippoOlivo
2025-03-12 12:03:47 +01:00
committed by Nicola Demo
parent 59e6ee595c
commit ae796ce34c
6 changed files with 128 additions and 97 deletions

View File

@@ -32,8 +32,8 @@ class ConditionInterface(metaclass=ABCMeta):
"""
Set the problem to which the condition is associated.
:param value: Problem to which the condition is associated.
:type value: pina.problem.AbstractProblem
:param pina.problem.AbstractProblem value: Problem to which the
condition is associated.
"""
self._problem = value

View File

@@ -12,7 +12,7 @@ from ..graph import Graph
class DataCondition(ConditionInterface):
"""
This condition must be used every time a Unsupervised Loss is needed in
the Solver. The conditionalvariable can be passed as extra-input when
the Solver. The `conditional_variable` can be passed as extra-input when
the model learns a conditional distribution.
"""
@@ -31,7 +31,8 @@ class DataCondition(ConditionInterface):
:param conditional_variables: Conditional variables for the condition.
:type conditional_variables: torch.Tensor | LabelTensor
:return: Subclass of DataCondition.
:rtype: TensorDataCondition | GraphDataCondition
:rtype: pina.condition.data_condition.TensorDataCondition |
pina.condition.data_condition.GraphDataCondition
:raises ValueError: If input is not of type :class:`torch.Tensor`,
:class:`pina.label_tensor.LabelTensor`, :class:`pina.graph.Graph`,

View File

@@ -30,7 +30,9 @@ class InputEquationCondition(ConditionInterface):
:param EquationInterface equation: Equation object containing the
equation function.
:return: Subclass of InputEquationCondition, based on the input type.
:rtype: InputTensorEquationCondition | InputGraphEquationCondition
:rtype: pina.condition.input_equation_condition.
InputTensorEquationCondition |
pina.condition.input_equation_condition.InputGraphEquationCondition
:raises ValueError: If input is not of type
:class:`pina.label_tensor.LabelTensor`, :class:`pina.graph.Graph`.
@@ -105,7 +107,7 @@ class InputGraphEquationCondition(InputEquationCondition):
in the :class:`pina.graph.Graph` object.
:param input: Input data.
:type input: torch.Tensor | Graph | torch_geometric.data.Data
:type input: torch.Tensor | Graph | Data
:raises ValueError: If the input data object does not contain at least
one LabelTensor.

View File

@@ -25,23 +25,26 @@ class InputTargetCondition(ConditionInterface):
the types of input and target data.
:param input: Input data for the condition.
:type input: torch.Tensor | LabelTensor | Graph |
torch_geometric.data.Data | list[Graph] |
:type input: torch.Tensor | LabelTensor | Graph |
torch_geometric.data.Data | list[Graph] |
list[torch_geometric.data.Data] | tuple[Graph] |
tuple[torch_geometric.data.Data]
:param target: Target data for the condition.
:type target: torch.Tensor | LabelTensor | Graph |
torch_geometric.data.Data | list[Graph] |
:type target: torch.Tensor | LabelTensor | Graph |
torch_geometric.data.Data | list[Graph] |
list[torch_geometric.data.Data] | tuple[Graph] |
tuple[torch_geometric.data.Data]
:return: Subclass of InputTargetCondition
:rtype: TensorInputTensorTargetCondition | \
TensorInputGraphTargetCondition | \
GraphInputTensorTargetCondition | \
GraphInputGraphTargetCondition
:rtype: pina.condition.input_target_condition.
TensorInputTensorTargetCondition |
pina.condition.input_target_condition.
TensorInputGraphTargetCondition |
pina.condition.input_target_condition.
GraphInputTensorTargetCondition |
pina.condition.input_target_condition.GraphInputGraphTargetCondition
:raises ValueError: If input and or target are not of type
:class:`torch.Tensor`, :class:`pina.label_tensor.LabelTensor`,
:class:`torch.Tensor`, :class:`pina.label_tensor.LabelTensor`,
:class:`pina.graph.Graph`, or :class:`~torch_geometric.data.Data`.
"""
if cls != InputTargetCondition: