Add check conditions-solver consistency
This commit is contained in:
committed by
Nicola Demo
parent
f2340cd4ee
commit
a6f0336d06
@@ -8,6 +8,8 @@ from ...utils import check_consistency
|
||||
from ...loss.loss_interface import LossInterface
|
||||
from ...problem import InverseProblem
|
||||
from ...optim import TorchOptimizer, TorchScheduler
|
||||
from ...condition import InputOutputPointsCondition, \
|
||||
InputPointsEquationCondition, DomainEquationCondition
|
||||
|
||||
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
|
||||
|
||||
@@ -24,6 +26,8 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
to the user to choose which problem the implemented solver inheriting from
|
||||
this class is suitable for.
|
||||
"""
|
||||
accepted_conditions_types = (InputOutputPointsCondition,
|
||||
InputPointsEquationCondition, DomainEquationCondition)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -97,7 +101,6 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
self._optimizer = self._pina_optimizers[0]
|
||||
self._scheduler = self._pina_schedulers[0]
|
||||
|
||||
|
||||
def training_step(self, batch):
|
||||
"""
|
||||
The Physics Informed Solver Training Step. This function takes care
|
||||
@@ -117,14 +120,16 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
if 'output_points' in points:
|
||||
input_pts, output_pts = points['input_points'], points['output_points']
|
||||
|
||||
loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
|
||||
loss_ = self.loss_data(
|
||||
input_pts=input_pts, output_pts=output_pts)
|
||||
condition_loss.append(loss_.as_subclass(torch.Tensor))
|
||||
else:
|
||||
input_pts = points['input_points']
|
||||
|
||||
condition = self.problem.conditions[condition_name]
|
||||
|
||||
loss_ = self.loss_phys(input_pts.requires_grad_(), condition.equation)
|
||||
loss_ = self.loss_phys(
|
||||
input_pts.requires_grad_(), condition.equation)
|
||||
condition_loss.append(loss_.as_subclass(torch.Tensor))
|
||||
condition_loss.append(loss_.as_subclass(torch.Tensor))
|
||||
# clamp unknown parameters in InverseProblem (if needed)
|
||||
@@ -144,14 +149,16 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
for condition_name, points in batch:
|
||||
if 'output_points' in points:
|
||||
input_pts, output_pts = points['input_points'], points['output_points']
|
||||
loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
|
||||
loss_ = self.loss_data(
|
||||
input_pts=input_pts, output_pts=output_pts)
|
||||
condition_loss.append(loss_.as_subclass(torch.Tensor))
|
||||
else:
|
||||
input_pts = points['input_points']
|
||||
|
||||
condition = self.problem.conditions[condition_name]
|
||||
with torch.set_grad_enabled(True):
|
||||
loss_ = self.loss_phys(input_pts.requires_grad_(), condition.equation)
|
||||
loss_ = self.loss_phys(
|
||||
input_pts.requires_grad_(), condition.equation)
|
||||
condition_loss.append(loss_.as_subclass(torch.Tensor))
|
||||
condition_loss.append(loss_.as_subclass(torch.Tensor))
|
||||
# clamp unknown parameters in InverseProblem (if needed)
|
||||
|
||||
Reference in New Issue
Block a user