Add check conditions-solver consistency

This commit is contained in:
FilippoOlivo
2025-01-16 19:42:46 +01:00
committed by Nicola Demo
parent f2340cd4ee
commit a6f0336d06
3 changed files with 20 additions and 22 deletions

View File

@@ -8,6 +8,8 @@ from ...utils import check_consistency
from ...loss.loss_interface import LossInterface from ...loss.loss_interface import LossInterface
from ...problem import InverseProblem from ...problem import InverseProblem
from ...optim import TorchOptimizer, TorchScheduler from ...optim import TorchOptimizer, TorchScheduler
from ...condition import InputOutputPointsCondition, \
InputPointsEquationCondition, DomainEquationCondition
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732 torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
@@ -24,6 +26,8 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
to the user to choose which problem the implemented solver inheriting from to the user to choose which problem the implemented solver inheriting from
this class is suitable for. this class is suitable for.
""" """
accepted_conditions_types = (InputOutputPointsCondition,
InputPointsEquationCondition, DomainEquationCondition)
def __init__( def __init__(
self, self,
@@ -97,7 +101,6 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
self._optimizer = self._pina_optimizers[0] self._optimizer = self._pina_optimizers[0]
self._scheduler = self._pina_schedulers[0] self._scheduler = self._pina_schedulers[0]
def training_step(self, batch): def training_step(self, batch):
""" """
The Physics Informed Solver Training Step. This function takes care The Physics Informed Solver Training Step. This function takes care
@@ -117,14 +120,16 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
if 'output_points' in points: if 'output_points' in points:
input_pts, output_pts = points['input_points'], points['output_points'] input_pts, output_pts = points['input_points'], points['output_points']
loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts) loss_ = self.loss_data(
input_pts=input_pts, output_pts=output_pts)
condition_loss.append(loss_.as_subclass(torch.Tensor)) condition_loss.append(loss_.as_subclass(torch.Tensor))
else: else:
input_pts = points['input_points'] input_pts = points['input_points']
condition = self.problem.conditions[condition_name] condition = self.problem.conditions[condition_name]
loss_ = self.loss_phys(input_pts.requires_grad_(), condition.equation) loss_ = self.loss_phys(
input_pts.requires_grad_(), condition.equation)
condition_loss.append(loss_.as_subclass(torch.Tensor)) condition_loss.append(loss_.as_subclass(torch.Tensor))
condition_loss.append(loss_.as_subclass(torch.Tensor)) condition_loss.append(loss_.as_subclass(torch.Tensor))
# clamp unknown parameters in InverseProblem (if needed) # clamp unknown parameters in InverseProblem (if needed)
@@ -144,14 +149,16 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
for condition_name, points in batch: for condition_name, points in batch:
if 'output_points' in points: if 'output_points' in points:
input_pts, output_pts = points['input_points'], points['output_points'] input_pts, output_pts = points['input_points'], points['output_points']
loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts) loss_ = self.loss_data(
input_pts=input_pts, output_pts=output_pts)
condition_loss.append(loss_.as_subclass(torch.Tensor)) condition_loss.append(loss_.as_subclass(torch.Tensor))
else: else:
input_pts = points['input_points'] input_pts = points['input_points']
condition = self.problem.conditions[condition_name] condition = self.problem.conditions[condition_name]
with torch.set_grad_enabled(True): with torch.set_grad_enabled(True):
loss_ = self.loss_phys(input_pts.requires_grad_(), condition.equation) loss_ = self.loss_phys(
input_pts.requires_grad_(), condition.equation)
condition_loss.append(loss_.as_subclass(torch.Tensor)) condition_loss.append(loss_.as_subclass(torch.Tensor))
condition_loss.append(loss_.as_subclass(torch.Tensor)) condition_loss.append(loss_.as_subclass(torch.Tensor))
# clamp unknown parameters in InverseProblem (if needed) # clamp unknown parameters in InverseProblem (if needed)

View File

@@ -10,7 +10,6 @@ import torch
import sys import sys
class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
""" """
Solver base class. This class inherits is a wrapper of Solver base class. This class inherits is a wrapper of
@@ -133,23 +132,14 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
return super().on_train_start() return super().on_train_start()
def _check_solver_consistency(self, problem):
pass
#TODO : Implement this method for the conditions
'''
for _, condition in problem.conditions.items():
if not set(condition.condition_type).issubset(
set(self.accepted_condition_types)):
raise ValueError(
f'{self.__name__} dose not support condition '
f'{condition.condition_type}')
'''
@staticmethod @staticmethod
def get_batch_size(batch): def get_batch_size(batch):
# Assuming batch is your custom Batch object # Assuming batch is your custom Batch object
batch_size = 0 batch_size = 0
for data in batch: for data in batch:
batch_size += len(data[1]['input_points']) batch_size += len(data[1]['input_points'])
return batch_size return batch_size
def _check_solver_consistency(self, problem):
for condition in problem.conditions.values():
check_consistency(condition, self.accepted_conditions_types)

View File

@@ -7,6 +7,7 @@ from .solver import SolverInterface
from ..label_tensor import LabelTensor from ..label_tensor import LabelTensor
from ..utils import check_consistency from ..utils import check_consistency
from ..loss.loss_interface import LossInterface from ..loss.loss_interface import LossInterface
from ..condition import InputOutputPointsCondition
class SupervisedSolver(SolverInterface): class SupervisedSolver(SolverInterface):
@@ -37,7 +38,8 @@ class SupervisedSolver(SolverInterface):
we are seeking to approximate multiple (discretised) functions given we are seeking to approximate multiple (discretised) functions given
multiple (discretised) input functions. multiple (discretised) input functions.
""" """
__name__ = 'SupervisedSolver'
accepted_conditions_types = InputOutputPointsCondition
def __init__(self, def __init__(self,
problem, problem,
@@ -143,7 +145,6 @@ class SupervisedSolver(SolverInterface):
self.log('val_loss', loss, prog_bar=True, logger=True, self.log('val_loss', loss, prog_bar=True, logger=True,
batch_size=self.get_batch_size(batch), sync_dist=True) batch_size=self.get_batch_size(batch), sync_dist=True)
def test_step(self, batch, batch_idx) -> STEP_OUTPUT: def test_step(self, batch, batch_idx) -> STEP_OUTPUT:
""" """
Solver test step. Solver test step.