Improve conditions and refactor dataset classes (#475)

* Reimplement conditions

* Refactor datasets and implement LabelBatch

---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2025-03-07 11:24:09 +01:00
committed by Nicola Demo
parent bdad144461
commit a0cbf1c44a
40 changed files with 943 additions and 550 deletions

View File

@@ -1,34 +1,84 @@
"""
Module that defines the ConditionInterface class.
"""
from abc import ABCMeta
from torch_geometric.data import Data
from ..label_tensor import LabelTensor
from ..graph import Graph
class ConditionInterface(metaclass=ABCMeta):
    """
    Abstract class which defines a common interface for all the conditions.

    It stores the problem the condition is associated with, the (optional)
    condition type, and provides a helper to validate that a list of
    ``Graph``/``Data`` objects is internally consistent.
    """

    # Admissible values for the ``condition_type`` property.
    condition_types = ["physics", "supervised", "unsupervised"]

    def __init__(self, *args, **kwargs):
        """
        Initialize the condition with no associated problem and no type.

        Positional and keyword arguments are accepted for compatibility with
        subclasses and are not used here.
        """
        self._condition_type = None
        self._problem = None

    @property
    def problem(self):
        """
        Return the problem to which the condition is associated.

        :return: Problem to which the condition is associated
        :rtype: pina.problem.AbstractProblem
        """
        return self._problem

    @problem.setter
    def problem(self, value):
        """
        Set the problem to which the condition is associated.

        :param value: Problem to associate with the condition. No validation
            is performed on it.
        """
        self._problem = value

    @property
    def condition_type(self):
        """
        Return the condition type(s) currently set.

        :return: ``None`` if never set, otherwise the list or tuple stored by
            the setter.
        """
        return self._condition_type

    @condition_type.setter
    def condition_type(self, values):
        """
        Set the condition type(s).

        A single value is wrapped into a one-element list; a list or tuple is
        stored as given.

        :param values: One condition type, or a list/tuple of them.
        :raises ValueError: If any value is not in ``condition_types``.
        """
        if not isinstance(values, (list, tuple)):
            values = [values]
        for value in values:
            if value not in ConditionInterface.condition_types:
                raise ValueError(
                    "Unavailable type of condition, expected one of"
                    f" {ConditionInterface.condition_types}."
                )
        self._condition_type = values

    @staticmethod
    def _check_graph_list_consistency(data_list):
        """
        Check that all the elements of ``data_list`` share the same structure.

        The first element is taken as reference; every other element must
        expose the same keys, the same class for each stored tensor and, for
        ``LabelTensor`` entries, the same labels.

        :param data_list: A single ``Graph``/``Data`` object (accepted without
            further checks) or an indexable collection of such objects.
            NOTE(review): an empty collection reaches ``data_list[0]`` and
            raises ``IndexError`` -- confirm whether callers can pass one.
        :raises ValueError: If an element is not a ``Graph``/``Data`` object,
            or if keys, tensor classes or labels differ from the first
            element.
        """
        # If the data is a Graph or Data object, return (do not need to check
        # anything)
        if isinstance(data_list, (Graph, Data)):
            return

        # check all elements in the list are of the same type
        if not all(isinstance(i, (Graph, Data)) for i in data_list):
            raise ValueError(
                "Invalid input types. "
                "Please provide either Data or Graph objects."
            )

        data = data_list[0]
        # Store the keys of the first element in the list
        keys = sorted(data.keys())

        # Store the type of each tensor inside first element Data/Graph object
        data_types = {name: tensor.__class__ for name, tensor in data.items()}

        # Store the labels of each LabelTensor inside first element Data/Graph
        # object
        labels = {
            name: tensor.labels
            for name, tensor in data.items()
            if isinstance(tensor, LabelTensor)
        }

        # Iterate over the list of Data/Graph objects
        for data in data_list[1:]:
            # Check if the keys of the current element are the same as the first
            # element
            if sorted(data.keys()) != keys:
                raise ValueError(
                    "All elements in the list must have the same keys."
                )
            for name, tensor in data.items():
                # Check if the type of each tensor inside the current element
                # is the same as the first element
                if tensor.__class__ is not data_types[name]:
                    raise ValueError(
                        f"Data {name} must be a {data_types[name]}, got "
                        f"{tensor.__class__}"
                    )
                # If the tensor is a LabelTensor, check if the labels are the
                # same as the first element
                if isinstance(tensor, LabelTensor):
                    if tensor.labels != labels[name]:
                        raise ValueError(
                            "LabelTensor must have the same labels"
                        )