Additional fix in condition
This commit is contained in:
@@ -11,8 +11,8 @@ from .condition_interface import ConditionInterface
|
||||
|
||||
class InputTargetCondition(ConditionInterface):
|
||||
"""
|
||||
Condition for domain/equation data. This condition must be used every
|
||||
time a Physics Informed or a Supervised Loss is needed in the Solver.
|
||||
Condition defined by input and target data. This condition can be used in
|
||||
both supervised learning and Physics-informed problems.
|
||||
"""
|
||||
|
||||
__slots__ = ["input", "target"]
|
||||
@@ -25,15 +25,11 @@ class InputTargetCondition(ConditionInterface):
|
||||
the types of input and target data.
|
||||
|
||||
:param input: Input data for the condition.
|
||||
:type input: torch.Tensor | LabelTensor | Graph |
|
||||
torch_geometric.data.Data | list[Graph] |
|
||||
list[torch_geometric.data.Data] | tuple[Graph] |
|
||||
tuple[torch_geometric.data.Data]
|
||||
:type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
|
||||
list[Data] | tuple[Graph] | tuple[Data]
|
||||
:param target: Target data for the condition.
|
||||
:type target: torch.Tensor | LabelTensor | Graph |
|
||||
torch_geometric.data.Data | list[Graph] |
|
||||
list[torch_geometric.data.Data] | tuple[Graph] |
|
||||
tuple[torch_geometric.data.Data]
|
||||
:type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
|
||||
list[Data] | tuple[Graph] | tuple[Data]
|
||||
:return: Subclass of InputTargetCondition
|
||||
:rtype: pina.condition.input_target_condition.
|
||||
TensorInputTensorTargetCondition |
|
||||
@@ -43,7 +39,7 @@ class InputTargetCondition(ConditionInterface):
|
||||
GraphInputTensorTargetCondition |
|
||||
pina.condition.input_target_condition.GraphInputGraphTargetCondition
|
||||
|
||||
:raises ValueError: If input and or target are not of type
|
||||
:raises ValueError: If `input` and/or `target` are not of type
|
||||
:class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`,
|
||||
:class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`.
|
||||
"""
|
||||
@@ -85,12 +81,11 @@ class InputTargetCondition(ConditionInterface):
|
||||
|
||||
def __init__(self, input, target):
|
||||
"""
|
||||
Initialize the object storing the input and target data.
|
||||
Initialize the object by storing the `input` and `target` data.
|
||||
|
||||
:param input: Input data for the condition.
|
||||
:type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
|
||||
list[Data] | tuple[Graph] |
|
||||
tuple[Data]
|
||||
list[Data] | tuple[Graph] | tuple[Data]
|
||||
:param target: Target data for the condition.
|
||||
:type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
|
||||
list[Data] | tuple[Graph] | tuple[Data]
|
||||
@@ -122,15 +117,15 @@ class InputTargetCondition(ConditionInterface):
|
||||
class TensorInputTensorTargetCondition(InputTargetCondition):
|
||||
"""
|
||||
InputTargetCondition subclass for :class:`torch.Tensor` or
|
||||
:class:`~pina.label_tensor.LabelTensor` input and target data.
|
||||
:class:`~pina.label_tensor.LabelTensor` `input` and `target` data.
|
||||
"""
|
||||
|
||||
|
||||
class TensorInputGraphTargetCondition(InputTargetCondition):
|
||||
"""
|
||||
InputTargetCondition subclass for :class:`torch.Tensor` or
|
||||
:class:`~pina.label_tensor.LabelTensor` input and
|
||||
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` target
|
||||
:class:`~pina.label_tensor.LabelTensor` `input` and
|
||||
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` `target`
|
||||
data.
|
||||
"""
|
||||
|
||||
@@ -138,13 +133,13 @@ class TensorInputGraphTargetCondition(InputTargetCondition):
|
||||
class GraphInputTensorTargetCondition(InputTargetCondition):
|
||||
"""
|
||||
InputTargetCondition subclass for :class:`~pina.graph.Graph` o
|
||||
:class:`~torch_geometric.data.Data` input and :class:`torch.Tensor` or
|
||||
:class:`~pina.label_tensor.LabelTensor` target data.
|
||||
:class:`~torch_geometric.data.Data` `input` and :class:`torch.Tensor` or
|
||||
:class:`~pina.label_tensor.LabelTensor` `target` data.
|
||||
"""
|
||||
|
||||
|
||||
class GraphInputGraphTargetCondition(InputTargetCondition):
|
||||
"""
|
||||
InputTargetCondition subclass for :class:`~pina.graph.Graph`/
|
||||
:class:`~torch_geometric.data.Data` input and target data.
|
||||
:class:`~torch_geometric.data.Data` `input` and `target` data.
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user