Additional fix in condition
This commit is contained in:
@@ -98,11 +98,11 @@ class Condition:
|
|||||||
- `input`: :class:`~pina.condition.data_condition.DataCondition`
|
- `input`: :class:`~pina.condition.data_condition.DataCondition`
|
||||||
- `input` and `conditional_variables`:
|
- `input` and `conditional_variables`:
|
||||||
:class:`~pina.condition.data_condition.DataCondition`
|
:class:`~pina.condition.data_condition.DataCondition`
|
||||||
|
|
||||||
:raises ValueError: No valid condition has been found.
|
|
||||||
:return: A new condition instance belonging to the proper class.
|
:return: A new condition instance belonging to the proper class.
|
||||||
:rtype: InputTargetCondition | DomainEquationCondition |
|
:rtype: InputTargetCondition | DomainEquationCondition |
|
||||||
InputEquationCondition | DataCondition
|
InputEquationCondition | DataCondition
|
||||||
|
|
||||||
|
:raises ValueError: No valid condition has been found.
|
||||||
"""
|
"""
|
||||||
if len(args) != 0:
|
if len(args) != 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -11,9 +11,8 @@ from ..graph import Graph
|
|||||||
|
|
||||||
class DataCondition(ConditionInterface):
|
class DataCondition(ConditionInterface):
|
||||||
"""
|
"""
|
||||||
This condition must be used every time a Unsupervised Loss is needed in
|
Condition defined by input data and conditional variables. It can be used
|
||||||
the Solver. The `conditional_variable` can be passed as extra-input when
|
in unsupervised learning problems.
|
||||||
the model learns a conditional distribution.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__slots__ = ["input", "conditional_variables"]
|
__slots__ = ["input", "conditional_variables"]
|
||||||
@@ -22,14 +21,14 @@ class DataCondition(ConditionInterface):
|
|||||||
|
|
||||||
def __new__(cls, input, conditional_variables=None):
|
def __new__(cls, input, conditional_variables=None):
|
||||||
"""
|
"""
|
||||||
Instantiate the appropriate subclass of DataCondition based on the
|
Instantiate the appropriate subclass of :class:`DataCondition` based on
|
||||||
types of input data.
|
the type of `input`.
|
||||||
|
|
||||||
:param input: Input data for the condition.
|
:param input: Input data for the condition.
|
||||||
:type input: torch.Tensor | LabelTensor | Graph |
|
:type input: torch.Tensor | LabelTensor | Graph |
|
||||||
Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data]
|
Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data]
|
||||||
:param conditional_variables: Conditional variables for the condition.
|
:param conditional_variables: Conditional variables for the condition.
|
||||||
:type conditional_variables: torch.Tensor | LabelTensor
|
:type conditional_variables: torch.Tensor | LabelTensor, optional
|
||||||
:return: Subclass of DataCondition.
|
:return: Subclass of DataCondition.
|
||||||
:rtype: pina.condition.data_condition.TensorDataCondition |
|
:rtype: pina.condition.data_condition.TensorDataCondition |
|
||||||
pina.condition.data_condition.GraphDataCondition
|
pina.condition.data_condition.GraphDataCondition
|
||||||
@@ -37,9 +36,8 @@ class DataCondition(ConditionInterface):
|
|||||||
:raises ValueError: If input is not of type :class:`torch.Tensor`,
|
:raises ValueError: If input is not of type :class:`torch.Tensor`,
|
||||||
:class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`,
|
:class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`,
|
||||||
or :class:`~torch_geometric.data.Data`.
|
or :class:`~torch_geometric.data.Data`.
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if cls != DataCondition:
|
if cls != DataCondition:
|
||||||
return super().__new__(cls)
|
return super().__new__(cls)
|
||||||
if isinstance(input, (torch.Tensor, LabelTensor)):
|
if isinstance(input, (torch.Tensor, LabelTensor)):
|
||||||
@@ -69,8 +67,8 @@ class DataCondition(ConditionInterface):
|
|||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
If either `input` is composed by a list of
|
If either `input` is composed by a list of
|
||||||
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`
|
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`,
|
||||||
objects, all elements must have the same structure (keys and data
|
all elements must have the same structure (keys and data
|
||||||
types)
|
types)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@@ -10,15 +10,16 @@ from ..equation.equation_interface import EquationInterface
|
|||||||
|
|
||||||
class DomainEquationCondition(ConditionInterface):
|
class DomainEquationCondition(ConditionInterface):
|
||||||
"""
|
"""
|
||||||
Condition for domain/equation data. This condition must be used every
|
Condition defined by a domain and an equation. It can be used in Physics
|
||||||
time a Physics Informed Loss is needed in the Solver.
|
Informed problems. Before using this condition, make sure that input data
|
||||||
|
are correctly sampled from the domain.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__slots__ = ["domain", "equation"]
|
__slots__ = ["domain", "equation"]
|
||||||
|
|
||||||
def __init__(self, domain, equation):
|
def __init__(self, domain, equation):
|
||||||
"""
|
"""
|
||||||
Initialize the object by storing the domain and equation.
|
Initialise the object by storing the domain and equation.
|
||||||
|
|
||||||
:param DomainInterface domain: Domain object containing the domain data.
|
:param DomainInterface domain: Domain object containing the domain data.
|
||||||
:param EquationInterface equation: Equation object containing the
|
:param EquationInterface equation: Equation object containing the
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ from ..equation.equation_interface import EquationInterface
|
|||||||
|
|
||||||
class InputEquationCondition(ConditionInterface):
|
class InputEquationCondition(ConditionInterface):
|
||||||
"""
|
"""
|
||||||
Condition for input/equation data. This condition must be used every
|
Condition defined by input data and an equation. This condition can be
|
||||||
time a Physics Informed Loss is needed in the Solver.
|
used in a Physics Informed problems.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__slots__ = ["input", "equation"]
|
__slots__ = ["input", "equation"]
|
||||||
@@ -22,10 +22,10 @@ class InputEquationCondition(ConditionInterface):
|
|||||||
|
|
||||||
def __new__(cls, input, equation):
|
def __new__(cls, input, equation):
|
||||||
"""
|
"""
|
||||||
Instantiate the appropriate subclass of InputEquationCondition based on
|
Instantiate the appropriate subclass of :class:`InputEquationCondition`
|
||||||
the type of input data.
|
based on the type of `input`.
|
||||||
|
|
||||||
:param input: Input data. It can be a LabelTensor or a Graph object.
|
:param input: Input data for the condition.
|
||||||
:type input: LabelTensor | Graph | list[Graph] | tuple[Graph]
|
:type input: LabelTensor | Graph | list[Graph] | tuple[Graph]
|
||||||
:param EquationInterface equation: Equation object containing the
|
:param EquationInterface equation: Equation object containing the
|
||||||
equation function.
|
equation function.
|
||||||
|
|||||||
@@ -11,8 +11,8 @@ from .condition_interface import ConditionInterface
|
|||||||
|
|
||||||
class InputTargetCondition(ConditionInterface):
|
class InputTargetCondition(ConditionInterface):
|
||||||
"""
|
"""
|
||||||
Condition for domain/equation data. This condition must be used every
|
Condition defined by input and target data. This condition can be used in
|
||||||
time a Physics Informed or a Supervised Loss is needed in the Solver.
|
both supervised learning and Physics-informed problems.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__slots__ = ["input", "target"]
|
__slots__ = ["input", "target"]
|
||||||
@@ -25,15 +25,11 @@ class InputTargetCondition(ConditionInterface):
|
|||||||
the types of input and target data.
|
the types of input and target data.
|
||||||
|
|
||||||
:param input: Input data for the condition.
|
:param input: Input data for the condition.
|
||||||
:type input: torch.Tensor | LabelTensor | Graph |
|
:type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
|
||||||
torch_geometric.data.Data | list[Graph] |
|
list[Data] | tuple[Graph] | tuple[Data]
|
||||||
list[torch_geometric.data.Data] | tuple[Graph] |
|
|
||||||
tuple[torch_geometric.data.Data]
|
|
||||||
:param target: Target data for the condition.
|
:param target: Target data for the condition.
|
||||||
:type target: torch.Tensor | LabelTensor | Graph |
|
:type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
|
||||||
torch_geometric.data.Data | list[Graph] |
|
list[Data] | tuple[Graph] | tuple[Data]
|
||||||
list[torch_geometric.data.Data] | tuple[Graph] |
|
|
||||||
tuple[torch_geometric.data.Data]
|
|
||||||
:return: Subclass of InputTargetCondition
|
:return: Subclass of InputTargetCondition
|
||||||
:rtype: pina.condition.input_target_condition.
|
:rtype: pina.condition.input_target_condition.
|
||||||
TensorInputTensorTargetCondition |
|
TensorInputTensorTargetCondition |
|
||||||
@@ -43,7 +39,7 @@ class InputTargetCondition(ConditionInterface):
|
|||||||
GraphInputTensorTargetCondition |
|
GraphInputTensorTargetCondition |
|
||||||
pina.condition.input_target_condition.GraphInputGraphTargetCondition
|
pina.condition.input_target_condition.GraphInputGraphTargetCondition
|
||||||
|
|
||||||
:raises ValueError: If input and or target are not of type
|
:raises ValueError: If `input` and/or `target` are not of type
|
||||||
:class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`,
|
:class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`,
|
||||||
:class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`.
|
:class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`.
|
||||||
"""
|
"""
|
||||||
@@ -85,12 +81,11 @@ class InputTargetCondition(ConditionInterface):
|
|||||||
|
|
||||||
def __init__(self, input, target):
|
def __init__(self, input, target):
|
||||||
"""
|
"""
|
||||||
Initialize the object storing the input and target data.
|
Initialize the object by storing the `input` and `target` data.
|
||||||
|
|
||||||
:param input: Input data for the condition.
|
:param input: Input data for the condition.
|
||||||
:type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
|
:type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
|
||||||
list[Data] | tuple[Graph] |
|
list[Data] | tuple[Graph] | tuple[Data]
|
||||||
tuple[Data]
|
|
||||||
:param target: Target data for the condition.
|
:param target: Target data for the condition.
|
||||||
:type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
|
:type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
|
||||||
list[Data] | tuple[Graph] | tuple[Data]
|
list[Data] | tuple[Graph] | tuple[Data]
|
||||||
@@ -122,15 +117,15 @@ class InputTargetCondition(ConditionInterface):
|
|||||||
class TensorInputTensorTargetCondition(InputTargetCondition):
|
class TensorInputTensorTargetCondition(InputTargetCondition):
|
||||||
"""
|
"""
|
||||||
InputTargetCondition subclass for :class:`torch.Tensor` or
|
InputTargetCondition subclass for :class:`torch.Tensor` or
|
||||||
:class:`~pina.label_tensor.LabelTensor` input and target data.
|
:class:`~pina.label_tensor.LabelTensor` `input` and `target` data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class TensorInputGraphTargetCondition(InputTargetCondition):
|
class TensorInputGraphTargetCondition(InputTargetCondition):
|
||||||
"""
|
"""
|
||||||
InputTargetCondition subclass for :class:`torch.Tensor` or
|
InputTargetCondition subclass for :class:`torch.Tensor` or
|
||||||
:class:`~pina.label_tensor.LabelTensor` input and
|
:class:`~pina.label_tensor.LabelTensor` `input` and
|
||||||
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` target
|
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` `target`
|
||||||
data.
|
data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -138,13 +133,13 @@ class TensorInputGraphTargetCondition(InputTargetCondition):
|
|||||||
class GraphInputTensorTargetCondition(InputTargetCondition):
|
class GraphInputTensorTargetCondition(InputTargetCondition):
|
||||||
"""
|
"""
|
||||||
InputTargetCondition subclass for :class:`~pina.graph.Graph` o
|
InputTargetCondition subclass for :class:`~pina.graph.Graph` o
|
||||||
:class:`~torch_geometric.data.Data` input and :class:`torch.Tensor` or
|
:class:`~torch_geometric.data.Data` `input` and :class:`torch.Tensor` or
|
||||||
:class:`~pina.label_tensor.LabelTensor` target data.
|
:class:`~pina.label_tensor.LabelTensor` `target` data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class GraphInputGraphTargetCondition(InputTargetCondition):
|
class GraphInputGraphTargetCondition(InputTargetCondition):
|
||||||
"""
|
"""
|
||||||
InputTargetCondition subclass for :class:`~pina.graph.Graph`/
|
InputTargetCondition subclass for :class:`~pina.graph.Graph`/
|
||||||
:class:`~torch_geometric.data.Data` input and target data.
|
:class:`~torch_geometric.data.Data` `input` and `target` data.
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user