This commit is contained in:
Nicola Demo
2024-08-05 17:34:34 +02:00
parent 686b557144
commit 5245a0b68c
19 changed files with 483 additions and 173 deletions

View File

@@ -0,0 +1,28 @@
from .condition_interface import ConditionInterface
class DomainEquationCondition(ConditionInterface):
"""
Condition for input/output data.
"""
__slots__ = ["domain", "equation"]
def __init__(self, domain, equation):
"""
Constructor for the `InputOutputCondition` class.
"""
super().__init__()
self.domain = domain
self.equation = equation
@staticmethod
def batch_residual(model, input_pts, equation):
"""
Compute the residual of the condition for a single batch. Input and
output points are provided as arguments.
:param torch.nn.Module model: The model to evaluate the condition.
:param torch.Tensor input_points: The input points.
:param torch.Tensor output_points: The output points.
"""
return equation.residual(model(input_pts))