diff --git a/pina/solvers/pinns/basepinn.py b/pina/solvers/pinns/basepinn.py index f44ea1d..611a3f4 100644 --- a/pina/solvers/pinns/basepinn.py +++ b/pina/solvers/pinns/basepinn.py @@ -8,6 +8,8 @@ from ...utils import check_consistency from ...loss.loss_interface import LossInterface from ...problem import InverseProblem from ...optim import TorchOptimizer, TorchScheduler +from ...condition import InputOutputPointsCondition, \ + InputPointsEquationCondition, DomainEquationCondition torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732 @@ -24,6 +26,8 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta): to the user to choose which problem the implemented solver inheriting from this class is suitable for. """ + accepted_conditions_types = (InputOutputPointsCondition, + InputPointsEquationCondition, DomainEquationCondition) def __init__( self, @@ -97,7 +101,6 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta): self._optimizer = self._pina_optimizers[0] self._scheduler = self._pina_schedulers[0] - def training_step(self, batch): """ The Physics Informed Solver Training Step. This function takes care @@ -117,14 +120,16 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta): if 'output_points' in points: input_pts, output_pts = points['input_points'], points['output_points'] - loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts) + loss_ = self.loss_data( + input_pts=input_pts, output_pts=output_pts) condition_loss.append(loss_.as_subclass(torch.Tensor)) else: input_pts = points['input_points'] condition = self.problem.conditions[condition_name] - loss_ = self.loss_phys(input_pts.requires_grad_(), condition.equation) + loss_ = self.loss_phys( + input_pts.requires_grad_(), condition.equation) condition_loss.append(loss_.as_subclass(torch.Tensor)) condition_loss.append(loss_.as_subclass(torch.Tensor)) # clamp unknown parameters in InverseProblem (if needed) @@ -144,14 +149,16 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta): for condition_name, points in batch: if 'output_points' in points: input_pts, output_pts = points['input_points'], points['output_points'] - loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts) + loss_ = self.loss_data( + input_pts=input_pts, output_pts=output_pts) condition_loss.append(loss_.as_subclass(torch.Tensor)) else: input_pts = points['input_points'] condition = self.problem.conditions[condition_name] with torch.set_grad_enabled(True): - loss_ = self.loss_phys(input_pts.requires_grad_(), condition.equation) + loss_ = self.loss_phys( + input_pts.requires_grad_(), condition.equation) condition_loss.append(loss_.as_subclass(torch.Tensor)) condition_loss.append(loss_.as_subclass(torch.Tensor)) # clamp unknown parameters in InverseProblem (if needed) diff --git a/pina/solvers/solver.py b/pina/solvers/solver.py index 8052b4b..408aee5 100644 --- a/pina/solvers/solver.py +++ b/pina/solvers/solver.py @@ -10,7 +10,6 @@ import torch import sys - class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): """ Solver base class. This class inherits is a wrapper of @@ -133,23 +132,14 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): return super().on_train_start() - def _check_solver_consistency(self, problem): - pass - #TODO : Implement this method for the conditions - ''' - - - for _, condition in problem.conditions.items(): - if not set(condition.condition_type).issubset( - set(self.accepted_condition_types)): - raise ValueError( - f'{self.__name__} dose not support condition ' - f'{condition.condition_type}') - ''' @staticmethod def get_batch_size(batch): # Assuming batch is your custom Batch object batch_size = 0 for data in batch: batch_size += len(data[1]['input_points']) - return batch_size \ No newline at end of file + return batch_size + + def _check_solver_consistency(self, problem): + for condition in problem.conditions.values(): + check_consistency(condition, self.accepted_conditions_types) diff --git a/pina/solvers/supervised.py b/pina/solvers/supervised.py index bce4b31..c7f5f66 100644 --- a/pina/solvers/supervised.py +++ b/pina/solvers/supervised.py @@ -7,6 +7,7 @@ from .solver import SolverInterface from ..label_tensor import LabelTensor from ..utils import check_consistency from ..loss.loss_interface import LossInterface +from ..condition import InputOutputPointsCondition class SupervisedSolver(SolverInterface): @@ -37,7 +38,8 @@ class SupervisedSolver(SolverInterface): we are seeking to approximate multiple (discretised) functions given multiple (discretised) input functions. """ - __name__ = 'SupervisedSolver' + + accepted_conditions_types = InputOutputPointsCondition def __init__(self, problem, @@ -143,7 +145,6 @@ class SupervisedSolver(SolverInterface): self.log('val_loss', loss, prog_bar=True, logger=True, batch_size=self.get_batch_size(batch), sync_dist=True) - def test_step(self, batch, batch_idx) -> STEP_OUTPUT: """ Solver test step.