supervised working
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
__all__ = [
|
||||
'Condition',
|
||||
'ConditionInterface',
|
||||
'InputOutputCondition',
|
||||
'InputEquationCondition'
|
||||
'LocationEquationCondition',
|
||||
'DomainOutputCondition',
|
||||
'DomainEquationCondition'
|
||||
]
|
||||
|
||||
from .condition_interface import ConditionInterface
|
||||
from .input_output_condition import InputOutputCondition
|
||||
from .domain_output_condition import DomainOutputCondition
|
||||
from .domain_equation_condition import DomainEquationCondition
|
||||
@@ -1,9 +1,11 @@
|
||||
""" Condition module. """
|
||||
|
||||
from ..label_tensor import LabelTensor
|
||||
from ..geometry import Location
|
||||
from ..domain import DomainInterface
|
||||
from ..equation.equation import Equation
|
||||
|
||||
from . import DomainOutputCondition, DomainEquationCondition
|
||||
|
||||
|
||||
def dummy(a):
|
||||
"""Dummy function for testing purposes."""
|
||||
@@ -51,14 +53,6 @@ class Condition:
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = [
|
||||
"input_points",
|
||||
"output_points",
|
||||
"location",
|
||||
"equation",
|
||||
"data_weight",
|
||||
]
|
||||
|
||||
# def _dictvalue_isinstance(self, dict_, key_, class_):
|
||||
# """Check if the value of a dictionary corresponding to `key` is an instance of `class_`."""
|
||||
# if key_ not in dict_.keys():
|
||||
@@ -77,11 +71,17 @@ class Condition:
|
||||
# f"Condition takes only the following keyword arguments: {Condition.__slots__}."
|
||||
# )
|
||||
|
||||
from . import InputOutputCondition
|
||||
def __new__(cls, *args, **kwargs):
|
||||
|
||||
if sorted(kwargs.keys()) == sorted(["input_points", "output_points"]):
|
||||
return InputOutputCondition(**kwargs)
|
||||
return DomainOutputCondition(
|
||||
domain=kwargs["input_points"],
|
||||
output_points=kwargs["output_points"]
|
||||
)
|
||||
elif sorted(kwargs.keys()) == sorted(["domain", "output_points"]):
|
||||
return DomainOutputCondition(**kwargs)
|
||||
elif sorted(kwargs.keys()) == sorted(["domain", "equation"]):
|
||||
return DomainEquationCondition(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
|
||||
|
||||
|
||||
@@ -4,6 +4,9 @@ from abc import ABCMeta, abstractmethod
|
||||
|
||||
class ConditionInterface(metaclass=ABCMeta):
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._problem = None
|
||||
|
||||
@abstractmethod
|
||||
def residual(self, model):
|
||||
"""
|
||||
|
||||
@@ -1,26 +1,34 @@
|
||||
|
||||
from . import ConditionInterface
|
||||
|
||||
class InputOutputCondition(ConditionInterface):
|
||||
class DomainOutputCondition(ConditionInterface):
|
||||
"""
|
||||
Condition for input/output data.
|
||||
"""
|
||||
|
||||
__slots__ = ["input_points", "output_points"]
|
||||
__slots__ = ["domain", "output_points"]
|
||||
|
||||
def __init__(self, input_points, output_points):
|
||||
def __init__(self, domain, output_points):
|
||||
"""
|
||||
Constructor for the `InputOutputCondition` class.
|
||||
"""
|
||||
super().__init__()
|
||||
self.input_points = input_points
|
||||
print(self)
|
||||
self.domain = domain
|
||||
self.output_points = output_points
|
||||
|
||||
@property
|
||||
def input_points(self):
|
||||
"""
|
||||
Get the input points of the condition.
|
||||
"""
|
||||
return self._problem.domains[self.domain]
|
||||
|
||||
def residual(self, model):
|
||||
"""
|
||||
Compute the residual of the condition.
|
||||
"""
|
||||
return self.batch_residual(model, self.input_points, self.output_points)
|
||||
return self.batch_residual(model, self.domain, self.output_points)
|
||||
|
||||
@staticmethod
|
||||
def batch_residual(model, input_points, output_points):
|
||||
@@ -1,7 +1,7 @@
|
||||
|
||||
from . import ConditionInterface
|
||||
|
||||
class InputOutputCondition(ConditionInterface):
|
||||
class InputEquationCondition(ConditionInterface):
|
||||
"""
|
||||
Condition for input/output data.
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user