34 lines
1.0 KiB
Python
34 lines
1.0 KiB
Python
from .condition_interface import ConditionInterface
|
|
|
|
class DomainEquationCondition(ConditionInterface):
    """
    Condition defined by a domain and an equation.

    The condition stores a ``domain`` and an ``equation`` and exposes the
    residual obtained by evaluating the equation on the model output.
    """

    __slots__ = ["domain", "equation"]

    def __init__(self, domain, equation):
        """
        Constructor for the `DomainEquationCondition` class.

        :param domain: The domain associated with the condition. It is
            stored as given and later forwarded to the model as its input
            by :meth:`residual`.
        :param equation: The equation associated with the condition. It
            must expose a ``residual(input, output)`` method.
        """
        super().__init__()
        self.domain = domain
        self.equation = equation

    def residual(self, model):
        """
        Compute the residual of the condition on the stored domain.

        :param model: Callable evaluated on ``self.domain`` (presumably a
            ``torch.nn.Module`` -- confirm against callers).
        :return: The value returned by :meth:`batch_residual` for the
            stored domain and equation.
        """
        # The result must be propagated to the caller; without ``return``
        # this method would always yield ``None``.
        return self.batch_residual(model, self.domain, self.equation)

    @staticmethod
    def batch_residual(model, input_pts, equation):
        """
        Compute the residual of the condition for a single batch. The input
        points and the equation are provided as arguments.

        :param model: Callable evaluated on ``input_pts`` (presumably a
            ``torch.nn.Module`` -- confirm against callers).
        :param input_pts: The input points fed to ``model`` (presumably a
            ``torch.Tensor`` -- confirm against callers).
        :param equation: The equation object; its
            ``residual(input_pts, output)`` method is called with the
            input points and the model output.
        :return: Whatever ``equation.residual`` returns.
        """
        return equation.residual(input_pts, model(input_pts))