new conditions
This commit is contained in:
committed by
Nicola Demo
parent
a888141707
commit
fd16fcf9b4
@@ -1,34 +1,43 @@
|
||||
import torch
|
||||
|
||||
from .condition_interface import ConditionInterface
|
||||
from ..label_tensor import LabelTensor
|
||||
from ..graph import Graph
|
||||
from ..utils import check_consistency
|
||||
from ..domain import DomainInterface
|
||||
from ..equation.equation_interface import EquationInterface
|
||||
|
||||
class DomainEquationCondition(ConditionInterface):
|
||||
"""
|
||||
Condition for input/output data.
|
||||
Condition for domain/equation data. This condition must be used every
|
||||
time a Physics Informed Loss is needed in the Solver.
|
||||
"""
|
||||
|
||||
__slots__ = ["domain", "equation"]
|
||||
|
||||
def __init__(self, domain, equation):
|
||||
"""
|
||||
Constructor for the `InputOutputCondition` class.
|
||||
TODO
|
||||
"""
|
||||
super().__init__()
|
||||
self.domain = domain
|
||||
self.equation = equation
|
||||
self.condition_type = 'physics'
|
||||
|
||||
def residual(self, model):
|
||||
"""
|
||||
Compute the residual of the condition.
|
||||
"""
|
||||
self.batch_residual(model, self.domain, self.equation)
|
||||
@property
|
||||
def domain(self):
|
||||
return self._domain
|
||||
|
||||
@domain.setter
|
||||
def domain(self, value):
|
||||
check_consistency(value, (DomainInterface))
|
||||
self._domain = value
|
||||
|
||||
@staticmethod
|
||||
def batch_residual(model, input_pts, equation):
|
||||
"""
|
||||
Compute the residual of the condition for a single batch. Input and
|
||||
output points are provided as arguments.
|
||||
|
||||
:param torch.nn.Module model: The model to evaluate the condition.
|
||||
:param torch.Tensor input_pts: The input points.
|
||||
:param torch.Tensor equation: The output points.
|
||||
"""
|
||||
return equation.residual(input_pts, model(input_pts))
|
||||
@property
|
||||
def equation(self):
|
||||
return self._equation
|
||||
|
||||
@equation.setter
|
||||
def equation(self, value):
|
||||
check_consistency(value, (EquationInterface))
|
||||
self._equation = value
|
||||
Reference in New Issue
Block a user