add exhaustive doc for condition module (#629)
This commit is contained in:
@@ -8,24 +8,25 @@ from ..graph import Graph
|
||||
|
||||
class ConditionInterface(metaclass=ABCMeta):
|
||||
"""
|
||||
Abstract class which defines a common interface for all the conditions.
|
||||
It defined a common interface for all the conditions.
|
||||
Abstract base class for PINA conditions. All specific conditions must
|
||||
inherit from this interface.
|
||||
|
||||
Refer to :class:`pina.condition.condition.Condition` for a thorough
|
||||
description of all available conditions and how to instantiate them.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialize the ConditionInterface object.
|
||||
Initialization of the :class:`ConditionInterface` class.
|
||||
"""
|
||||
|
||||
self._problem = None
|
||||
|
||||
@property
|
||||
def problem(self):
|
||||
"""
|
||||
Return the problem to which the condition is associated.
|
||||
Return the problem associated with this condition.
|
||||
|
||||
:return: Problem to which the condition is associated.
|
||||
:return: Problem associated with this condition.
|
||||
:rtype: ~pina.problem.abstract_problem.AbstractProblem
|
||||
"""
|
||||
return self._problem
|
||||
@@ -33,31 +34,32 @@ class ConditionInterface(metaclass=ABCMeta):
|
||||
@problem.setter
|
||||
def problem(self, value):
|
||||
"""
|
||||
Set the problem to which the condition is associated.
|
||||
Set the problem associated with this condition.
|
||||
|
||||
:param pina.problem.abstract_problem.AbstractProblem value: Problem to
|
||||
which the condition is associated
|
||||
:param pina.problem.abstract_problem.AbstractProblem value: The problem
|
||||
to associate with this condition
|
||||
"""
|
||||
self._problem = value
|
||||
|
||||
@staticmethod
|
||||
def _check_graph_list_consistency(data_list):
|
||||
"""
|
||||
Check the consistency of the list of Data/Graph objects. It performs
|
||||
the following checks:
|
||||
Check the consistency of the list of Data | Graph objects.
|
||||
The following checks are performed:
|
||||
|
||||
1. All elements in the list must be of the same type (either Data or
|
||||
Graph).
|
||||
2. All elements in the list must have the same keys.
|
||||
3. The type of each tensor must be consistent across all elements in
|
||||
the list.
|
||||
4. If the tensor is a LabelTensor, the labels must be consistent across
|
||||
all elements in the list.
|
||||
- All elements in the list must be of the same type (either
|
||||
:class:`~torch_geometric.data.Data` or :class:`~pina.graph.Graph`).
|
||||
|
||||
:param data_list: List of Data/Graph objects to check
|
||||
- All elements in the list must have the same keys.
|
||||
|
||||
- The data type of each tensor must be consistent across all elements.
|
||||
|
||||
- If a tensor is a :class:`~pina.label_tensor.LabelTensor`, its labels
|
||||
must also be consistent across all elements.
|
||||
|
||||
:param data_list: The list of Data | Graph objects to check.
|
||||
:type data_list: list[Data] | list[Graph] | tuple[Data] | tuple[Graph]
|
||||
|
||||
:raises ValueError: If the input types are invalid.
|
||||
:raises ValueError: If the input types are invalid.
|
||||
:raises ValueError: If all elements in the list do not have the same
|
||||
keys.
|
||||
:raises ValueError: If the type of each tensor is not consistent across
|
||||
@@ -65,51 +67,45 @@ class ConditionInterface(metaclass=ABCMeta):
|
||||
:raises ValueError: If the labels of the LabelTensors are not consistent
|
||||
across all elements in the list.
|
||||
"""
|
||||
|
||||
# If the data is a Graph or Data object, return (do not need to check
|
||||
# anything)
|
||||
# If the data is a Graph or Data object, perform no checks
|
||||
if isinstance(data_list, (Graph, Data)):
|
||||
return
|
||||
|
||||
# check all elements in the list are of the same type
|
||||
# Check all elements in the list are of the same type
|
||||
if not all(isinstance(i, (Graph, Data)) for i in data_list):
|
||||
raise ValueError(
|
||||
"Invalid input types. "
|
||||
"Please provide either Data or Graph objects."
|
||||
"Invalid input. Please, provide either Data or Graph objects."
|
||||
)
|
||||
|
||||
# Store the keys, data types and labels of the first element
|
||||
data = data_list[0]
|
||||
# Store the keys of the first element in the list
|
||||
keys = sorted(list(data.keys()))
|
||||
|
||||
# Store the type of each tensor inside first element Data/Graph object
|
||||
data_types = {name: tensor.__class__ for name, tensor in data.items()}
|
||||
|
||||
# Store the labels of each LabelTensor inside first element Data/Graph
|
||||
# object
|
||||
labels = {
|
||||
name: tensor.labels
|
||||
for name, tensor in data.items()
|
||||
if isinstance(tensor, LabelTensor)
|
||||
}
|
||||
|
||||
# Iterate over the list of Data/Graph objects
|
||||
# Iterate over the list of Data | Graph objects
|
||||
for data in data_list[1:]:
|
||||
# Check if the keys of the current element are the same as the first
|
||||
# element
|
||||
|
||||
# Check that all elements in the list have the same keys
|
||||
if sorted(list(data.keys())) != keys:
|
||||
raise ValueError(
|
||||
"All elements in the list must have the same keys."
|
||||
)
|
||||
|
||||
# Iterate over the tensors in the current element
|
||||
for name, tensor in data.items():
|
||||
# Check if the type of each tensor inside the current element
|
||||
# is the same as the first element
|
||||
# Check that the type of each tensor is consistent
|
||||
if tensor.__class__ is not data_types[name]:
|
||||
raise ValueError(
|
||||
f"Data {name} must be a {data_types[name]}, got "
|
||||
f"{tensor.__class__}"
|
||||
)
|
||||
# If the tensor is a LabelTensor, check if the labels are the
|
||||
# same as the first element
|
||||
|
||||
# Check that the labels of each LabelTensor are consistent
|
||||
if isinstance(tensor, LabelTensor):
|
||||
if tensor.labels != labels[name]:
|
||||
raise ValueError(
|
||||
@@ -117,6 +113,13 @@ class ConditionInterface(metaclass=ABCMeta):
|
||||
)
|
||||
|
||||
def __getattribute__(self, name):
|
||||
"""
|
||||
Get an attribute from the object.
|
||||
|
||||
:param str name: The name of the attribute to get.
|
||||
:return: The requested attribute.
|
||||
:rtype: Any
|
||||
"""
|
||||
to_return = super().__getattribute__(name)
|
||||
if isinstance(to_return, (Graph, Data)):
|
||||
to_return = [to_return]
|
||||
|
||||
Reference in New Issue
Block a user