committed by
Nicola Demo
parent
f0d68b34c7
commit
30f865d912
@@ -84,14 +84,15 @@ class Condition:
|
||||
return DomainEquationCondition(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
|
||||
|
||||
# TODO: remove, not used anymore
|
||||
'''
|
||||
if (
|
||||
sorted(kwargs.keys()) != sorted(["input_points", "output_points"])
|
||||
and sorted(kwargs.keys()) != sorted(["location", "equation"])
|
||||
and sorted(kwargs.keys()) != sorted(["input_points", "equation"])
|
||||
):
|
||||
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
|
||||
|
||||
# TODO: remove, not used anymore
|
||||
if not self._dictvalue_isinstance(kwargs, "input_points", LabelTensor):
|
||||
raise TypeError("`input_points` must be a torch.Tensor.")
|
||||
if not self._dictvalue_isinstance(kwargs, "output_points", LabelTensor):
|
||||
@@ -103,3 +104,4 @@ class Condition:
|
||||
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
'''
|
||||
@@ -15,4 +15,7 @@ class ConditionInterface(metaclass=ABCMeta):
|
||||
:param model: The model to evaluate the condition.
|
||||
:return: The residual of the condition.
|
||||
"""
|
||||
pass
|
||||
pass
|
||||
|
||||
def set_problem(self, problem):
|
||||
self._problem = problem
|
||||
|
||||
@@ -15,6 +15,12 @@ class DomainEquationCondition(ConditionInterface):
|
||||
self.domain = domain
|
||||
self.equation = equation
|
||||
|
||||
def residual(self, model):
|
||||
"""
|
||||
Compute the residual of the condition.
|
||||
"""
|
||||
self.batch_residual(model, self.domain, self.equation)
|
||||
|
||||
@staticmethod
|
||||
def batch_residual(model, input_pts, equation):
|
||||
"""
|
||||
@@ -22,7 +28,7 @@ class DomainEquationCondition(ConditionInterface):
|
||||
output points are provided as arguments.
|
||||
|
||||
:param torch.nn.Module model: The model to evaluate the condition.
|
||||
:param torch.Tensor input_points: The input points.
|
||||
:param torch.Tensor output_points: The output points.
|
||||
:param torch.Tensor input_pts: The input points.
|
||||
:param torch.Tensor equation: The output points.
|
||||
"""
|
||||
return equation.residual(model(input_pts))
|
||||
return equation.residual(input_pts, model(input_pts))
|
||||
@@ -40,4 +40,5 @@ class DomainOutputCondition(ConditionInterface):
|
||||
:param torch.Tensor input_points: The input points.
|
||||
:param torch.Tensor output_points: The output points.
|
||||
"""
|
||||
|
||||
return output_points - model(input_points)
|
||||
Reference in New Issue
Block a user