new conditions

This commit is contained in:
Dario Coscia
2024-10-03 21:33:37 +02:00
committed by Nicola Demo
parent a888141707
commit fd16fcf9b4
8 changed files with 210 additions and 171 deletions

View File

@@ -1,34 +1,43 @@
import torch
from .condition_interface import ConditionInterface
from ..label_tensor import LabelTensor
from ..graph import Graph
from ..utils import check_consistency
from ..domain import DomainInterface
from ..equation.equation_interface import EquationInterface
class DomainEquationCondition(ConditionInterface):
"""
Condition for input/output data.
Condition for domain/equation data. This condition must be used every
time a Physics Informed Loss is needed in the Solver.
"""
__slots__ = ["domain", "equation"]
def __init__(self, domain, equation):
"""
Constructor for the `InputOutputCondition` class.
TODO
"""
super().__init__()
self.domain = domain
self.equation = equation
self.condition_type = 'physics'
def residual(self, model):
"""
Compute the residual of the condition.
"""
self.batch_residual(model, self.domain, self.equation)
@property
def domain(self):
return self._domain
@domain.setter
def domain(self, value):
check_consistency(value, (DomainInterface))
self._domain = value
@staticmethod
def batch_residual(model, input_pts, equation):
"""
Compute the residual of the condition for a single batch. Input and
output points are provided as arguments.
:param torch.nn.Module model: The model to evaluate the condition.
:param torch.Tensor input_pts: The input points.
:param torch.Tensor equation: The output points.
"""
return equation.residual(input_pts, model(input_pts))
@property
def equation(self):
return self._equation
@equation.setter
def equation(self, value):
check_consistency(value, (EquationInterface))
self._equation = value