committed by
Nicola Demo
parent
f0d68b34c7
commit
30f865d912
@@ -15,6 +15,12 @@ class DomainEquationCondition(ConditionInterface):
|
||||
self.domain = domain
|
||||
self.equation = equation
|
||||
|
||||
def residual(self, model):
|
||||
"""
|
||||
Compute the residual of the condition.
|
||||
"""
|
||||
self.batch_residual(model, self.domain, self.equation)
|
||||
|
||||
@staticmethod
|
||||
def batch_residual(model, input_pts, equation):
|
||||
"""
|
||||
@@ -22,7 +28,7 @@ class DomainEquationCondition(ConditionInterface):
|
||||
output points are provided as arguments.
|
||||
|
||||
:param torch.nn.Module model: The model to evaluate the condition.
|
||||
:param torch.Tensor input_points: The input points.
|
||||
:param torch.Tensor output_points: The output points.
|
||||
:param torch.Tensor input_pts: The input points.
|
||||
:param torch.Tensor equation: The output points.
|
||||
"""
|
||||
return equation.residual(model(input_pts))
|
||||
return equation.residual(input_pts, model(input_pts))
|
||||
Reference in New Issue
Block a user