Improve conditions and refactor dataset classes (#475)
* Reimplement conditions * Refactor datasets and implement LabelBatch --------- Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
bdad144461
commit
a0cbf1c44a
@@ -1,12 +1,15 @@
|
||||
import torch
|
||||
"""
|
||||
DataCondition class
|
||||
"""
|
||||
|
||||
from . import ConditionInterface
|
||||
import torch
|
||||
from torch_geometric.data import Data
|
||||
from .condition_interface import ConditionInterface
|
||||
from ..label_tensor import LabelTensor
|
||||
from ..graph import Graph
|
||||
from ..utils import check_consistency
|
||||
|
||||
|
||||
class DataConditionInterface(ConditionInterface):
|
||||
class DataCondition(ConditionInterface):
|
||||
"""
|
||||
Condition for data. This condition must be used every
|
||||
time a Unsupervised Loss is needed in the Solver. The conditionalvariable
|
||||
@@ -14,19 +17,64 @@ class DataConditionInterface(ConditionInterface):
|
||||
distribution
|
||||
"""
|
||||
|
||||
__slots__ = ["input_points", "conditional_variables"]
|
||||
__slots__ = ["input", "conditional_variables"]
|
||||
_avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple)
|
||||
_avail_conditional_variables_cls = (torch.Tensor, LabelTensor)
|
||||
|
||||
def __init__(self, input_points, conditional_variables=None):
|
||||
def __new__(cls, input, conditional_variables=None):
|
||||
"""
|
||||
TODO : add docstring
|
||||
Instanciate the correct subclass of DataCondition by checking the type
|
||||
of the input data (input and conditional_variables).
|
||||
|
||||
:param input: torch.Tensor or Graph/Data object containing the input
|
||||
data
|
||||
:type input: torch.Tensor or Graph or Data
|
||||
:param conditional_variables: torch.Tensor or LabelTensor containing
|
||||
the conditional variables
|
||||
:type conditional_variables: torch.Tensor or LabelTensor
|
||||
:return: DataCondition subclass
|
||||
:rtype: TensorDataCondition or GraphDataCondition
|
||||
"""
|
||||
if cls != DataCondition:
|
||||
return super().__new__(cls)
|
||||
if isinstance(input, (torch.Tensor, LabelTensor)):
|
||||
subclass = TensorDataCondition
|
||||
return subclass.__new__(subclass, input, conditional_variables)
|
||||
|
||||
if isinstance(input, (Graph, Data, list, tuple)):
|
||||
cls._check_graph_list_consistency(input)
|
||||
subclass = GraphDataCondition
|
||||
return subclass.__new__(subclass, input, conditional_variables)
|
||||
|
||||
raise ValueError(
|
||||
"Invalid input types. "
|
||||
"Please provide either Data or Graph objects."
|
||||
)
|
||||
|
||||
def __init__(self, input, conditional_variables=None):
|
||||
"""
|
||||
Initialize the DataCondition, storing the input and conditional
|
||||
variables (if any).
|
||||
|
||||
:param input: torch.Tensor or Graph/Data object containing the input
|
||||
data
|
||||
:type input: torch.Tensor or Graph or Data
|
||||
:param conditional_variables: torch.Tensor or LabelTensor containing
|
||||
the conditional variables
|
||||
:type conditional_variables: torch.Tensor or LabelTensor
|
||||
"""
|
||||
super().__init__()
|
||||
self.input_points = input_points
|
||||
self.input = input
|
||||
self.conditional_variables = conditional_variables
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if (key == "input_points") or (key == "conditional_variables"):
|
||||
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
|
||||
DataConditionInterface.__dict__[key].__set__(self, value)
|
||||
elif key in ("_problem", "_condition_type"):
|
||||
super().__setattr__(key, value)
|
||||
|
||||
class TensorDataCondition(DataCondition):
|
||||
"""
|
||||
DataCondition for torch.Tensor input data
|
||||
"""
|
||||
|
||||
|
||||
class GraphDataCondition(DataCondition):
|
||||
"""
|
||||
DataCondition for Graph/Data input data
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user