Add check conditions-solver consistency

This commit is contained in:
FilippoOlivo
2025-01-16 19:42:46 +01:00
committed by Nicola Demo
parent f2340cd4ee
commit a6f0336d06
3 changed files with 20 additions and 22 deletions

View File

@@ -8,6 +8,8 @@ from ...utils import check_consistency
from ...loss.loss_interface import LossInterface
from ...problem import InverseProblem
from ...optim import TorchOptimizer, TorchScheduler
from ...condition import InputOutputPointsCondition, \
InputPointsEquationCondition, DomainEquationCondition
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
@@ -24,6 +26,8 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
to the user to choose which problem the implemented solver inheriting from
this class is suitable for.
"""
accepted_conditions_types = (InputOutputPointsCondition,
InputPointsEquationCondition, DomainEquationCondition)
def __init__(
self,
@@ -97,7 +101,6 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
self._optimizer = self._pina_optimizers[0]
self._scheduler = self._pina_schedulers[0]
def training_step(self, batch):
"""
The Physics Informed Solver Training Step. This function takes care
@@ -117,14 +120,16 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
if 'output_points' in points:
input_pts, output_pts = points['input_points'], points['output_points']
loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
loss_ = self.loss_data(
input_pts=input_pts, output_pts=output_pts)
condition_loss.append(loss_.as_subclass(torch.Tensor))
else:
input_pts = points['input_points']
condition = self.problem.conditions[condition_name]
loss_ = self.loss_phys(input_pts.requires_grad_(), condition.equation)
loss_ = self.loss_phys(
input_pts.requires_grad_(), condition.equation)
condition_loss.append(loss_.as_subclass(torch.Tensor))
condition_loss.append(loss_.as_subclass(torch.Tensor))
# clamp unknown parameters in InverseProblem (if needed)
@@ -144,14 +149,16 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
for condition_name, points in batch:
if 'output_points' in points:
input_pts, output_pts = points['input_points'], points['output_points']
loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
loss_ = self.loss_data(
input_pts=input_pts, output_pts=output_pts)
condition_loss.append(loss_.as_subclass(torch.Tensor))
else:
input_pts = points['input_points']
condition = self.problem.conditions[condition_name]
with torch.set_grad_enabled(True):
loss_ = self.loss_phys(input_pts.requires_grad_(), condition.equation)
loss_ = self.loss_phys(
input_pts.requires_grad_(), condition.equation)
condition_loss.append(loss_.as_subclass(torch.Tensor))
condition_loss.append(loss_.as_subclass(torch.Tensor))
# clamp unknown parameters in InverseProblem (if needed)

View File

@@ -10,7 +10,6 @@ import torch
import sys
class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
"""
Solver base class. This class inherits is a wrapper of
@@ -133,23 +132,14 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
return super().on_train_start()
def _check_solver_consistency(self, problem):
pass
#TODO : Implement this method for the conditions
'''
for _, condition in problem.conditions.items():
if not set(condition.condition_type).issubset(
set(self.accepted_condition_types)):
raise ValueError(
f'{self.__name__} dose not support condition '
f'{condition.condition_type}')
'''
@staticmethod
def get_batch_size(batch):
# Assuming batch is your custom Batch object
batch_size = 0
for data in batch:
batch_size += len(data[1]['input_points'])
return batch_size
return batch_size
def _check_solver_consistency(self, problem):
for condition in problem.conditions.values():
check_consistency(condition, self.accepted_conditions_types)

View File

@@ -7,6 +7,7 @@ from .solver import SolverInterface
from ..label_tensor import LabelTensor
from ..utils import check_consistency
from ..loss.loss_interface import LossInterface
from ..condition import InputOutputPointsCondition
class SupervisedSolver(SolverInterface):
@@ -37,7 +38,8 @@ class SupervisedSolver(SolverInterface):
we are seeking to approximate multiple (discretised) functions given
multiple (discretised) input functions.
"""
__name__ = 'SupervisedSolver'
accepted_conditions_types = InputOutputPointsCondition
def __init__(self,
problem,
@@ -143,7 +145,6 @@ class SupervisedSolver(SolverInterface):
self.log('val_loss', loss, prog_bar=True, logger=True,
batch_size=self.get_batch_size(batch), sync_dist=True)
def test_step(self, batch, batch_idx) -> STEP_OUTPUT:
"""
Solver test step.