minor changes/ trainer update

This commit is contained in:
Dario Coscia
2024-10-10 19:24:46 +02:00
committed by Nicola Demo
parent 7528f6ef74
commit b9753c34b2
8 changed files with 69 additions and 46 deletions

View File

@@ -13,20 +13,20 @@ class DataConditionInterface(ConditionInterface):
distribution
"""
__slots__ = ["data", "conditionalvariable"]
__slots__ = ["input_points", "conditional_variables"]
def __init__(self, data, conditionalvariable=None):
def __init__(self, input_points, conditional_variables=None):
"""
TODO
"""
super().__init__()
self.data = data
self.conditionalvariable = conditionalvariable
self.input_points = input_points
self.conditional_variables = conditional_variables
self.condition_type = 'unsupervised'
def __setattr__(self, key, value):
if (key == 'data') or (key == 'conditionalvariable'):
if (key == 'input_points') or (key == 'conditional_variables'):
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
DataConditionInterface.__dict__[key].__set__(self, value)
elif key in ('_condition_type', '_problem', 'problem', 'condition_type'):
elif key in ('problem', 'condition_type'):
super().__setattr__(key, value)

View File

@@ -29,5 +29,5 @@ class DomainEquationCondition(ConditionInterface):
elif key == 'equation':
check_consistency(value, (EquationInterface))
DomainEquationCondition.__dict__[key].__set__(self, value)
elif key in ('_condition_type', '_problem', 'problem', 'condition_type'):
elif key in ('problem', 'condition_type'):
super().__setattr__(key, value)

View File

@@ -30,5 +30,5 @@ class InputPointsEquationCondition(ConditionInterface):
elif key == 'equation':
check_consistency(value, (EquationInterface))
InputPointsEquationCondition.__dict__[key].__set__(self, value)
elif key in ('_condition_type', '_problem', 'problem', 'condition_type'):
elif key in ('problem', 'condition_type'):
super().__setattr__(key, value)

View File

@@ -27,5 +27,5 @@ class InputOutputPointsCondition(ConditionInterface):
if (key == 'input_points') or (key == 'output_points'):
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
InputOutputPointsCondition.__dict__[key].__set__(self, value)
elif key in ('_condition_type', '_problem', 'problem', 'condition_type'):
elif key in ('problem', 'condition_type'):
super().__setattr__(key, value)