Improve conditions and refactor dataset classes (#475)
* Reimplement conditions * Refactor datasets and implement LabelBatch --------- Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
bdad144461
commit
a0cbf1c44a
@@ -1,34 +1,84 @@
|
||||
"""
|
||||
Module that defines the ConditionInterface class.
|
||||
"""
|
||||
|
||||
from abc import ABCMeta
|
||||
from torch_geometric.data import Data
|
||||
from ..label_tensor import LabelTensor
|
||||
from ..graph import Graph
|
||||
|
||||
|
||||
class ConditionInterface(metaclass=ABCMeta):
|
||||
"""
|
||||
Abstract class which defines a common interface for all the conditions.
|
||||
"""
|
||||
|
||||
condition_types = ["physics", "supervised", "unsupervised"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._condition_type = None
|
||||
def __init__(self):
|
||||
self._problem = None
|
||||
|
||||
@property
|
||||
def problem(self):
|
||||
"""
|
||||
Return the problem to which the condition is associated.
|
||||
|
||||
:return: Problem to which the condition is associated
|
||||
:rtype: pina.problem.AbstractProblem
|
||||
"""
|
||||
return self._problem
|
||||
|
||||
@problem.setter
|
||||
def problem(self, value):
|
||||
self._problem = value
|
||||
|
||||
@property
|
||||
def condition_type(self):
|
||||
return self._condition_type
|
||||
@staticmethod
|
||||
def _check_graph_list_consistency(data_list):
|
||||
|
||||
@condition_type.setter
|
||||
def condition_type(self, values):
|
||||
if not isinstance(values, (list, tuple)):
|
||||
values = [values]
|
||||
for value in values:
|
||||
if value not in ConditionInterface.condition_types:
|
||||
# If the data is a Graph or Data object, return (do not need to check
|
||||
# anything)
|
||||
if isinstance(data_list, (Graph, Data)):
|
||||
return
|
||||
|
||||
# check all elements in the list are of the same type
|
||||
if not all(isinstance(i, (Graph, Data)) for i in data_list):
|
||||
raise ValueError(
|
||||
"Invalid input types. "
|
||||
"Please provide either Data or Graph objects."
|
||||
)
|
||||
data = data_list[0]
|
||||
# Store the keys of the first element in the list
|
||||
keys = sorted(list(data.keys()))
|
||||
|
||||
# Store the type of each tensor inside first element Data/Graph object
|
||||
data_types = {name: tensor.__class__ for name, tensor in data.items()}
|
||||
|
||||
# Store the labels of each LabelTensor inside first element Data/Graph
|
||||
# object
|
||||
labels = {
|
||||
name: tensor.labels
|
||||
for name, tensor in data.items()
|
||||
if isinstance(tensor, LabelTensor)
|
||||
}
|
||||
|
||||
# Iterate over the list of Data/Graph objects
|
||||
for data in data_list[1:]:
|
||||
# Check if the keys of the current element are the same as the first
|
||||
# element
|
||||
if sorted(list(data.keys())) != keys:
|
||||
raise ValueError(
|
||||
"Unavailable type of condition, expected one of"
|
||||
f" {ConditionInterface.condition_types}."
|
||||
"All elements in the list must have the same keys."
|
||||
)
|
||||
self._condition_type = values
|
||||
for name, tensor in data.items():
|
||||
# Check if the type of each tensor inside the current element
|
||||
# is the same as the first element
|
||||
if tensor.__class__ is not data_types[name]:
|
||||
raise ValueError(
|
||||
f"Data {name} must be a {data_types[name]}, got "
|
||||
f"{tensor.__class__}"
|
||||
)
|
||||
# If the tensor is a LabelTensor, check if the labels are the
|
||||
# same as the first element
|
||||
if isinstance(tensor, LabelTensor):
|
||||
if tensor.labels != labels[name]:
|
||||
raise ValueError(
|
||||
"LabelTensor must have the same labels"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user