* Adding Collector for handling data sampling/collection before dataset/dataloader

* Modify domain by adding sample_mode, variables as property
* Small change concatenate -> cat in lno/avno
* Create different factory classes for conditions
This commit is contained in:
Dario Coscia
2024-10-04 13:57:18 +02:00
committed by Nicola Demo
parent aef5a5d590
commit 1bd3f40f54
18 changed files with 225 additions and 277 deletions

View File

@@ -39,11 +39,10 @@ class Condition:
__slots__ = list(
set(
InputOutputPointsCondition.__slots__,
InputPointsEquationCondition.__slots__,
DomainEquationCondition.__slots__,
InputOutputPointsCondition.__slots__ +
InputPointsEquationCondition.__slots__ +
DomainEquationCondition.__slots__ +
DataConditionInterface.__slots__
)
)
@@ -51,8 +50,8 @@ class Condition:
if len(args) != 0:
raise ValueError(
f"Condition takes only the following keyword '
'arguments: {Condition.__slots__}."
"Condition takes only the following keyword "
f"arguments: {Condition.__slots__}."
)
sorted_keys = sorted(kwargs.keys())

View File

@@ -1,18 +1,27 @@
from abc import ABCMeta
class ConditionInterface(metaclass=ABCMeta):
condition_types = ['physical', 'supervised', 'unsupervised']
def __init__(self):
def __init__(self, *args, **wargs):
self._condition_type = None
self._problem = None
@property
def problem(self):
return self._problem
@problem.setter
def problem(self, value):
self._problem = value
@property
def condition_type(self):
return self._condition_type
@condition_type.setattr
@condition_type.setter
def condition_type(self, values):
if not isinstance(values, (list, tuple)):
values = [values]

View File

@@ -24,21 +24,7 @@ class DataConditionInterface(ConditionInterface):
self.conditionalvariable = conditionalvariable
self.condition_type = 'unsupervised'
@property
def data(self):
return self._data
@data.setter
def data(self, value):
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
self._data = value
@property
def conditionalvariable(self):
return self._conditionalvariable
@data.setter
def conditionalvariable(self, value):
if value is not None:
def __setattr__(self, key, value):
if (key == 'data') or (key == 'conditionalvariable'):
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
self._data = value
DataConditionInterface.__dict__[key].__set__(self, value)

View File

@@ -1,8 +1,6 @@
import torch
from .condition_interface import ConditionInterface
from ..label_tensor import LabelTensor
from ..graph import Graph
from ..utils import check_consistency
from ..domain import DomainInterface
from ..equation.equation_interface import EquationInterface
@@ -24,20 +22,10 @@ class DomainEquationCondition(ConditionInterface):
self.equation = equation
self.condition_type = 'physics'
@property
def domain(self):
return self._domain
@domain.setter
def domain(self, value):
check_consistency(value, (DomainInterface))
self._domain = value
@property
def equation(self):
return self._equation
@equation.setter
def equation(self, value):
check_consistency(value, (EquationInterface))
self._equation = value
def __setattr__(self, key, value):
if key == 'domain':
check_consistency(value, (DomainInterface))
DomainEquationCondition.__dict__[key].__set__(self, value)
elif key == 'equation':
check_consistency(value, (EquationInterface))
DomainEquationCondition.__dict__[key].__set__(self, value)

View File

@@ -23,20 +23,10 @@ class InputPointsEquationCondition(ConditionInterface):
self.equation = equation
self.condition_type = 'physics'
@property
def input_points(self):
return self._input_points
@input_points.setter
def input_points(self, value):
check_consistency(value, (LabelTensor)) # for now only labeltensors, we need labels for the operators!
self._input_points = value
@property
def equation(self):
return self._equation
@equation.setter
def equation(self, value):
check_consistency(value, (EquationInterface))
self._equation = value
def __setattr__(self, key, value):
if key == 'input_points':
check_consistency(value, (LabelTensor)) # for now only labeltensors, we need labels for the operators!
InputPointsEquationCondition.__dict__[key].__set__(self, value)
elif key == 'equation':
check_consistency(value, (EquationInterface))
InputPointsEquationCondition.__dict__[key].__set__(self, value)

View File

@@ -23,20 +23,7 @@ class InputOutputPointsCondition(ConditionInterface):
self.output_points = output_points
self.condition_type = ['supervised', 'physics']
@property
def input_points(self):
return self._input_points
@input_points.setter
def input_points(self, value):
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
self._input_points = value
@property
def output_points(self):
return self._output_points
@output_points.setter
def output_points(self, value):
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
self._output_points = value
def __setattr__(self, key, value):
if (key == 'input_points') or (key == 'output_points'):
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
InputOutputPointsCondition.__dict__[key].__set__(self, value)