* Adding Collector for handling data sampling/collection before dataset/dataloader
* Modify domain by adding sample_mode, variables as property * Small change concatenate -> cat in lno/avno * Create different factory classes for conditions
This commit is contained in:
committed by
Nicola Demo
parent
aef5a5d590
commit
1bd3f40f54
@@ -39,11 +39,10 @@ class Condition:
|
||||
|
||||
__slots__ = list(
|
||||
set(
|
||||
InputOutputPointsCondition.__slots__,
|
||||
InputPointsEquationCondition.__slots__,
|
||||
DomainEquationCondition.__slots__,
|
||||
InputOutputPointsCondition.__slots__ +
|
||||
InputPointsEquationCondition.__slots__ +
|
||||
DomainEquationCondition.__slots__ +
|
||||
DataConditionInterface.__slots__
|
||||
|
||||
)
|
||||
)
|
||||
|
||||
@@ -51,8 +50,8 @@ class Condition:
|
||||
|
||||
if len(args) != 0:
|
||||
raise ValueError(
|
||||
f"Condition takes only the following keyword '
|
||||
'arguments: {Condition.__slots__}."
|
||||
"Condition takes only the following keyword "
|
||||
f"arguments: {Condition.__slots__}."
|
||||
)
|
||||
|
||||
sorted_keys = sorted(kwargs.keys())
|
||||
|
||||
@@ -1,18 +1,27 @@
|
||||
|
||||
from abc import ABCMeta
|
||||
|
||||
|
||||
class ConditionInterface(metaclass=ABCMeta):
|
||||
|
||||
condition_types = ['physical', 'supervised', 'unsupervised']
|
||||
def __init__(self):
|
||||
|
||||
def __init__(self, *args, **wargs):
|
||||
self._condition_type = None
|
||||
self._problem = None
|
||||
|
||||
@property
|
||||
def problem(self):
|
||||
return self._problem
|
||||
|
||||
@problem.setter
|
||||
def problem(self, value):
|
||||
self._problem = value
|
||||
|
||||
@property
|
||||
def condition_type(self):
|
||||
return self._condition_type
|
||||
|
||||
@condition_type.setattr
|
||||
@condition_type.setter
|
||||
def condition_type(self, values):
|
||||
if not isinstance(values, (list, tuple)):
|
||||
values = [values]
|
||||
|
||||
@@ -24,21 +24,7 @@ class DataConditionInterface(ConditionInterface):
|
||||
self.conditionalvariable = conditionalvariable
|
||||
self.condition_type = 'unsupervised'
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self._data
|
||||
|
||||
@data.setter
|
||||
def data(self, value):
|
||||
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
|
||||
self._data = value
|
||||
|
||||
@property
|
||||
def conditionalvariable(self):
|
||||
return self._conditionalvariable
|
||||
|
||||
@data.setter
|
||||
def conditionalvariable(self, value):
|
||||
if value is not None:
|
||||
def __setattr__(self, key, value):
|
||||
if (key == 'data') or (key == 'conditionalvariable'):
|
||||
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
|
||||
self._data = value
|
||||
DataConditionInterface.__dict__[key].__set__(self, value)
|
||||
@@ -1,8 +1,6 @@
|
||||
import torch
|
||||
|
||||
from .condition_interface import ConditionInterface
|
||||
from ..label_tensor import LabelTensor
|
||||
from ..graph import Graph
|
||||
from ..utils import check_consistency
|
||||
from ..domain import DomainInterface
|
||||
from ..equation.equation_interface import EquationInterface
|
||||
@@ -24,20 +22,10 @@ class DomainEquationCondition(ConditionInterface):
|
||||
self.equation = equation
|
||||
self.condition_type = 'physics'
|
||||
|
||||
@property
|
||||
def domain(self):
|
||||
return self._domain
|
||||
|
||||
@domain.setter
|
||||
def domain(self, value):
|
||||
check_consistency(value, (DomainInterface))
|
||||
self._domain = value
|
||||
|
||||
@property
|
||||
def equation(self):
|
||||
return self._equation
|
||||
|
||||
@equation.setter
|
||||
def equation(self, value):
|
||||
check_consistency(value, (EquationInterface))
|
||||
self._equation = value
|
||||
def __setattr__(self, key, value):
|
||||
if key == 'domain':
|
||||
check_consistency(value, (DomainInterface))
|
||||
DomainEquationCondition.__dict__[key].__set__(self, value)
|
||||
elif key == 'equation':
|
||||
check_consistency(value, (EquationInterface))
|
||||
DomainEquationCondition.__dict__[key].__set__(self, value)
|
||||
@@ -23,20 +23,10 @@ class InputPointsEquationCondition(ConditionInterface):
|
||||
self.equation = equation
|
||||
self.condition_type = 'physics'
|
||||
|
||||
@property
|
||||
def input_points(self):
|
||||
return self._input_points
|
||||
|
||||
@input_points.setter
|
||||
def input_points(self, value):
|
||||
check_consistency(value, (LabelTensor)) # for now only labeltensors, we need labels for the operators!
|
||||
self._input_points = value
|
||||
|
||||
@property
|
||||
def equation(self):
|
||||
return self._equation
|
||||
|
||||
@equation.setter
|
||||
def equation(self, value):
|
||||
check_consistency(value, (EquationInterface))
|
||||
self._equation = value
|
||||
def __setattr__(self, key, value):
|
||||
if key == 'input_points':
|
||||
check_consistency(value, (LabelTensor)) # for now only labeltensors, we need labels for the operators!
|
||||
InputPointsEquationCondition.__dict__[key].__set__(self, value)
|
||||
elif key == 'equation':
|
||||
check_consistency(value, (EquationInterface))
|
||||
InputPointsEquationCondition.__dict__[key].__set__(self, value)
|
||||
@@ -23,20 +23,7 @@ class InputOutputPointsCondition(ConditionInterface):
|
||||
self.output_points = output_points
|
||||
self.condition_type = ['supervised', 'physics']
|
||||
|
||||
@property
|
||||
def input_points(self):
|
||||
return self._input_points
|
||||
|
||||
@input_points.setter
|
||||
def input_points(self, value):
|
||||
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
|
||||
self._input_points = value
|
||||
|
||||
@property
|
||||
def output_points(self):
|
||||
return self._output_points
|
||||
|
||||
@output_points.setter
|
||||
def output_points(self, value):
|
||||
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
|
||||
self._output_points = value
|
||||
def __setattr__(self, key, value):
|
||||
if (key == 'input_points') or (key == 'output_points'):
|
||||
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
|
||||
InputOutputPointsCondition.__dict__[key].__set__(self, value)
|
||||
Reference in New Issue
Block a user