supervised working

This commit is contained in:
Nicola Demo
2024-08-08 16:19:52 +02:00
parent 5245a0b68c
commit 9d9c2aa23e
61 changed files with 375 additions and 262 deletions

View File

@@ -1,10 +1,10 @@
__all__ = [
'Condition',
'ConditionInterface',
'InputOutputCondition',
'InputEquationCondition'
'LocationEquationCondition',
'DomainOutputCondition',
'DomainEquationCondition'
]
from .condition_interface import ConditionInterface
from .input_output_condition import InputOutputCondition
from .domain_output_condition import DomainOutputCondition
from .domain_equation_condition import DomainEquationCondition

View File

@@ -1,9 +1,11 @@
""" Condition module. """
from ..label_tensor import LabelTensor
from ..geometry import Location
from ..domain import DomainInterface
from ..equation.equation import Equation
from . import DomainOutputCondition, DomainEquationCondition
def dummy(a):
"""Dummy function for testing purposes."""
@@ -51,14 +53,6 @@ class Condition:
"""
__slots__ = [
"input_points",
"output_points",
"location",
"equation",
"data_weight",
]
# def _dictvalue_isinstance(self, dict_, key_, class_):
# """Check if the value of a dictionary corresponding to `key` is an instance of `class_`."""
# if key_ not in dict_.keys():
@@ -77,11 +71,17 @@ class Condition:
# f"Condition takes only the following keyword arguments: {Condition.__slots__}."
# )
from . import InputOutputCondition
def __new__(cls, *args, **kwargs):
if sorted(kwargs.keys()) == sorted(["input_points", "output_points"]):
return InputOutputCondition(**kwargs)
return DomainOutputCondition(
domain=kwargs["input_points"],
output_points=kwargs["output_points"]
)
elif sorted(kwargs.keys()) == sorted(["domain", "output_points"]):
return DomainOutputCondition(**kwargs)
elif sorted(kwargs.keys()) == sorted(["domain", "equation"]):
return DomainEquationCondition(**kwargs)
else:
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")

View File

@@ -4,6 +4,9 @@ from abc import ABCMeta, abstractmethod
class ConditionInterface(metaclass=ABCMeta):
def __init__(self) -> None:
self._problem = None
@abstractmethod
def residual(self, model):
"""

View File

@@ -1,26 +1,34 @@
from . import ConditionInterface
class InputOutputCondition(ConditionInterface):
class DomainOutputCondition(ConditionInterface):
"""
Condition for input/output data.
"""
__slots__ = ["input_points", "output_points"]
__slots__ = ["domain", "output_points"]
def __init__(self, input_points, output_points):
def __init__(self, domain, output_points):
"""
Constructor for the `InputOutputCondition` class.
"""
super().__init__()
self.input_points = input_points
print(self)
self.domain = domain
self.output_points = output_points
@property
def input_points(self):
"""
Get the input points of the condition.
"""
return self._problem.domains[self.domain]
def residual(self, model):
"""
Compute the residual of the condition.
"""
return self.batch_residual(model, self.input_points, self.output_points)
return self.batch_residual(model, self.domain, self.output_points)
@staticmethod
def batch_residual(model, input_points, output_points):

View File

@@ -1,7 +1,7 @@
from . import ConditionInterface
class InputOutputCondition(ConditionInterface):
class InputEquationCondition(ConditionInterface):
"""
Condition for input/output data.
"""