Add check conditions-solver consistency

This commit is contained in:
FilippoOlivo
2025-01-16 19:42:46 +01:00
committed by Nicola Demo
parent f2340cd4ee
commit a6f0336d06
3 changed files with 20 additions and 22 deletions

View File

@@ -8,6 +8,8 @@ from ...utils import check_consistency
from ...loss.loss_interface import LossInterface
from ...problem import InverseProblem
from ...optim import TorchOptimizer, TorchScheduler
from ...condition import InputOutputPointsCondition, \
InputPointsEquationCondition, DomainEquationCondition
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
@@ -24,6 +26,8 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
to the user to choose which problem the implemented solver inheriting from
this class is suitable for.
"""
accepted_conditions_types = (InputOutputPointsCondition,
InputPointsEquationCondition, DomainEquationCondition)
def __init__(
self,
@@ -97,7 +101,6 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
self._optimizer = self._pina_optimizers[0]
self._scheduler = self._pina_schedulers[0]
def training_step(self, batch):
"""
The Physics Informed Solver Training Step. This function takes care
@@ -117,14 +120,16 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
if 'output_points' in points:
input_pts, output_pts = points['input_points'], points['output_points']
loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
loss_ = self.loss_data(
input_pts=input_pts, output_pts=output_pts)
condition_loss.append(loss_.as_subclass(torch.Tensor))
else:
input_pts = points['input_points']
condition = self.problem.conditions[condition_name]
loss_ = self.loss_phys(input_pts.requires_grad_(), condition.equation)
loss_ = self.loss_phys(
input_pts.requires_grad_(), condition.equation)
condition_loss.append(loss_.as_subclass(torch.Tensor))
condition_loss.append(loss_.as_subclass(torch.Tensor))
# clamp unknown parameters in InverseProblem (if needed)
@@ -144,14 +149,16 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
for condition_name, points in batch:
if 'output_points' in points:
input_pts, output_pts = points['input_points'], points['output_points']
loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
loss_ = self.loss_data(
input_pts=input_pts, output_pts=output_pts)
condition_loss.append(loss_.as_subclass(torch.Tensor))
else:
input_pts = points['input_points']
condition = self.problem.conditions[condition_name]
with torch.set_grad_enabled(True):
loss_ = self.loss_phys(input_pts.requires_grad_(), condition.equation)
loss_ = self.loss_phys(
input_pts.requires_grad_(), condition.equation)
condition_loss.append(loss_.as_subclass(torch.Tensor))
condition_loss.append(loss_.as_subclass(torch.Tensor))
# clamp unknown parameters in InverseProblem (if needed)