add exhaustive doc for condition module (#629)

This commit is contained in:
Giovanni Canali
2025-09-11 15:47:06 +02:00
committed by GitHub
parent f3ccfd4598
commit a0015c3af6
6 changed files with 366 additions and 246 deletions

View File

@@ -8,24 +8,25 @@ from ..graph import Graph
class ConditionInterface(metaclass=ABCMeta):
"""
Abstract class which defines a common interface for all the conditions.
It defined a common interface for all the conditions.
Abstract base class for PINA conditions. All specific conditions must
inherit from this interface.
Refer to :class:`pina.condition.condition.Condition` for a thorough
description of all available conditions and how to instantiate them.
"""
def __init__(self):
"""
Initialize the ConditionInterface object.
Initialization of the :class:`ConditionInterface` class.
"""
self._problem = None
@property
def problem(self):
"""
Return the problem to which the condition is associated.
Return the problem associated with this condition.
:return: Problem to which the condition is associated.
:return: Problem associated with this condition.
:rtype: ~pina.problem.abstract_problem.AbstractProblem
"""
return self._problem
@@ -33,31 +34,32 @@ class ConditionInterface(metaclass=ABCMeta):
@problem.setter
def problem(self, value):
"""
Set the problem to which the condition is associated.
Set the problem associated with this condition.
:param pina.problem.abstract_problem.AbstractProblem value: Problem to
which the condition is associated
:param pina.problem.abstract_problem.AbstractProblem value: The problem
to associate with this condition
"""
self._problem = value
@staticmethod
def _check_graph_list_consistency(data_list):
"""
Check the consistency of the list of Data/Graph objects. It performs
the following checks:
Check the consistency of the list of Data | Graph objects.
The following checks are performed:
1. All elements in the list must be of the same type (either Data or
Graph).
2. All elements in the list must have the same keys.
3. The type of each tensor must be consistent across all elements in
the list.
4. If the tensor is a LabelTensor, the labels must be consistent across
all elements in the list.
- All elements in the list must be of the same type (either
:class:`~torch_geometric.data.Data` or :class:`~pina.graph.Graph`).
:param data_list: List of Data/Graph objects to check
- All elements in the list must have the same keys.
- The data type of each tensor must be consistent across all elements.
- If a tensor is a :class:`~pina.label_tensor.LabelTensor`, its labels
must also be consistent across all elements.
:param data_list: The list of Data | Graph objects to check.
:type data_list: list[Data] | list[Graph] | tuple[Data] | tuple[Graph]
:raises ValueError: If the input types are invalid.
:raises ValueError: If the input types are invalid.
:raises ValueError: If all elements in the list do not have the same
keys.
:raises ValueError: If the type of each tensor is not consistent across
@@ -65,51 +67,45 @@ class ConditionInterface(metaclass=ABCMeta):
:raises ValueError: If the labels of the LabelTensors are not consistent
across all elements in the list.
"""
# If the data is a Graph or Data object, return (do not need to check
# anything)
# If the data is a Graph or Data object, perform no checks
if isinstance(data_list, (Graph, Data)):
return
# check all elements in the list are of the same type
# Check all elements in the list are of the same type
if not all(isinstance(i, (Graph, Data)) for i in data_list):
raise ValueError(
"Invalid input types. "
"Please provide either Data or Graph objects."
"Invalid input. Please, provide either Data or Graph objects."
)
# Store the keys, data types and labels of the first element
data = data_list[0]
# Store the keys of the first element in the list
keys = sorted(list(data.keys()))
# Store the type of each tensor inside first element Data/Graph object
data_types = {name: tensor.__class__ for name, tensor in data.items()}
# Store the labels of each LabelTensor inside first element Data/Graph
# object
labels = {
name: tensor.labels
for name, tensor in data.items()
if isinstance(tensor, LabelTensor)
}
# Iterate over the list of Data/Graph objects
# Iterate over the list of Data | Graph objects
for data in data_list[1:]:
# Check if the keys of the current element are the same as the first
# element
# Check that all elements in the list have the same keys
if sorted(list(data.keys())) != keys:
raise ValueError(
"All elements in the list must have the same keys."
)
# Iterate over the tensors in the current element
for name, tensor in data.items():
# Check if the type of each tensor inside the current element
# is the same as the first element
# Check that the type of each tensor is consistent
if tensor.__class__ is not data_types[name]:
raise ValueError(
f"Data {name} must be a {data_types[name]}, got "
f"{tensor.__class__}"
)
# If the tensor is a LabelTensor, check if the labels are the
# same as the first element
# Check that the labels of each LabelTensor are consistent
if isinstance(tensor, LabelTensor):
if tensor.labels != labels[name]:
raise ValueError(
@@ -117,6 +113,13 @@ class ConditionInterface(metaclass=ABCMeta):
)
def __getattribute__(self, name):
"""
Get an attribute from the object.
:param str name: The name of the attribute to get.
:return: The requested attribute.
:rtype: Any
"""
to_return = super().__getattribute__(name)
if isinstance(to_return, (Graph, Data)):
to_return = [to_return]