Files
PINA/pina/condition/domain_equation_condition.py
FilippoOlivo 30f865d912 Fix bugs in 0.2 (#344)
* Fix some bugs
2025-03-19 17:46:33 +01:00

34 lines
1.0 KiB
Python

from .condition_interface import ConditionInterface
class DomainEquationCondition(ConditionInterface):
"""
Condition for input/output data.
"""
__slots__ = ["domain", "equation"]
def __init__(self, domain, equation):
"""
Constructor for the `InputOutputCondition` class.
"""
super().__init__()
self.domain = domain
self.equation = equation
def residual(self, model):
"""
Compute the residual of the condition.
"""
self.batch_residual(model, self.domain, self.equation)
@staticmethod
def batch_residual(model, input_pts, equation):
"""
Compute the residual of the condition for a single batch. Input and
output points are provided as arguments.
:param torch.nn.Module model: The model to evaluate the condition.
:param torch.Tensor input_pts: The input points.
:param torch.Tensor equation: The output points.
"""
return equation.residual(input_pts, model(input_pts))