Fix bugs in 0.2 (#344)

* Fix some bugs
This commit is contained in:
FilippoOlivo
2024-09-12 18:12:59 +02:00
committed by Nicola Demo
parent f0d68b34c7
commit 30f865d912
11 changed files with 112 additions and 55 deletions

View File

@@ -84,14 +84,15 @@ class Condition:
return DomainEquationCondition(**kwargs)
else:
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
# TODO: remove, not used anymore
'''
if (
sorted(kwargs.keys()) != sorted(["input_points", "output_points"])
and sorted(kwargs.keys()) != sorted(["location", "equation"])
and sorted(kwargs.keys()) != sorted(["input_points", "equation"])
):
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
# TODO: remove, not used anymore
if not self._dictvalue_isinstance(kwargs, "input_points", LabelTensor):
raise TypeError("`input_points` must be a torch.Tensor.")
if not self._dictvalue_isinstance(kwargs, "output_points", LabelTensor):
@@ -103,3 +104,4 @@ class Condition:
for key, value in kwargs.items():
setattr(self, key, value)
'''

View File

@@ -15,4 +15,7 @@ class ConditionInterface(metaclass=ABCMeta):
:param model: The model to evaluate the condition.
:return: The residual of the condition.
"""
pass
pass
def set_problem(self, problem):
self._problem = problem

View File

@@ -15,6 +15,12 @@ class DomainEquationCondition(ConditionInterface):
self.domain = domain
self.equation = equation
def residual(self, model):
"""
Compute the residual of the condition.
"""
self.batch_residual(model, self.domain, self.equation)
@staticmethod
def batch_residual(model, input_pts, equation):
"""
@@ -22,7 +28,7 @@ class DomainEquationCondition(ConditionInterface):
output points are provided as arguments.
:param torch.nn.Module model: The model to evaluate the condition.
:param torch.Tensor input_points: The input points.
:param torch.Tensor output_points: The output points.
:param torch.Tensor input_pts: The input points.
:param torch.Tensor equation: The output points.
"""
return equation.residual(model(input_pts))
return equation.residual(input_pts, model(input_pts))

View File

@@ -40,4 +40,5 @@ class DomainOutputCondition(ConditionInterface):
:param torch.Tensor input_points: The input points.
:param torch.Tensor output_points: The output points.
"""
return output_points - model(input_points)