From e1863d93188288ba235838bf0e40839f6841fb74 Mon Sep 17 00:00:00 2001
From: Dario Coscia <93731561+dario-coscia@users.noreply.github.com>
Date: Mon, 14 Apr 2025 09:45:01 +0200
Subject: [PATCH] Update PINNInterface Inheritance (#542)

---
 .../physics_informed_solver/pinn_interface.py | 63 ++++----------------
 pina/solver/solver.py                         | 34 +++++++++-
 2 files changed, 46 insertions(+), 51 deletions(-)

diff --git a/pina/solver/physics_informed_solver/pinn_interface.py b/pina/solver/physics_informed_solver/pinn_interface.py
index c53e123..976f6ce 100644
--- a/pina/solver/physics_informed_solver/pinn_interface.py
+++ b/pina/solver/physics_informed_solver/pinn_interface.py
@@ -2,12 +2,8 @@

 from abc import ABCMeta, abstractmethod
 import torch
-from torch.nn.modules.loss import _Loss

-from ..solver import SolverInterface
-from ...utils import check_consistency
-from ...loss.loss_interface import LossInterface
-from ...problem import InverseProblem
+from ..supervised_solver import SupervisedSolverInterface
 from ...condition import (
     InputTargetCondition,
     InputEquationCondition,
@@ -15,7 +11,7 @@ from ...condition import (
 )


-class PINNInterface(SolverInterface, metaclass=ABCMeta):
+class PINNInterface(SupervisedSolverInterface, metaclass=ABCMeta):
     """
     Base class for Physics-Informed Neural Network (PINN) solvers,
     implementing the :class:`~pina.solver.solver.SolverInterface` class.
@@ -32,7 +28,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
         DomainEquationCondition,
     )

-    def __init__(self, problem, loss=None, **kwargs):
+    def __init__(self, **kwargs):
         """
         Initialization of the :class:`PINNInterface` class.

@@ -41,28 +37,13 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
             If ``None``, the :class:`torch.nn.MSELoss` loss is used.
             Default is `None`.
         :param kwargs: Additional keyword arguments to be passed to the
-            :class:`~pina.solver.solver.SolverInterface` class.
+            :class:`~pina.solver.supervised_solver.SupervisedSolverInterface`
+            class.
         """
+        kwargs["use_lt"] = True
+        super().__init__(**kwargs)

-        if loss is None:
-            loss = torch.nn.MSELoss()
-
-        super().__init__(problem=problem, use_lt=True, **kwargs)
-
-        # check consistency
-        check_consistency(loss, (LossInterface, _Loss), subclass=False)
-
-        # assign variables
-        self._loss_fn = loss
-
-        # inverse problem handling
-        if isinstance(self.problem, InverseProblem):
-            self._params = self.problem.unknown_parameters
-            self._clamp_params = self._clamp_inverse_problem_params
-        else:
-            self._params = None
-            self._clamp_params = lambda: None
-
+        # current condition name
         self.__metric = None

     def optimization_cycle(self, batch, loss_residuals=None):
@@ -103,8 +84,6 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
             )
             # append loss
             condition_loss[condition_name] = loss
-        # clamp unknown parameters in InverseProblem (if needed)
-        self._clamp_params()
         return condition_loss

     @torch.set_grad_enabled(True)
@@ -135,7 +114,6 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
         """
         return super().test_step(batch, loss_residuals=self._residual_loss)

-    @abstractmethod
     def loss_data(self, input, target):
         """
         Compute the data loss for the PINN solver by evaluating the loss
@@ -147,7 +125,12 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
         :param LabelTensor input: The input to the neural network.
         :param LabelTensor target: The target to compare with the
             network's output.
         :return: The supervised loss, averaged over the number of observations.
         :rtype: LabelTensor
+        :raises NotImplementedError: If the method is not implemented.
         """
+        raise NotImplementedError(
+            "PINN is being used in a supervised learning context, but the "
+            "'loss_data' method has not been implemented. "
+        )

     @abstractmethod
     def loss_phys(self, samples, equation):
@@ -196,26 +179,6 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
         residuals = self.compute_residual(samples, equation)
         return self._loss_fn(residuals, torch.zeros_like(residuals))

-    def _clamp_inverse_problem_params(self):
-        """
-        Clamps the parameters of the inverse problem solver to specified ranges.
-        """
-        for v in self._params:
-            self._params[v].data.clamp_(
-                self.problem.unknown_parameter_domain.range_[v][0],
-                self.problem.unknown_parameter_domain.range_[v][1],
-            )
-
-    @property
-    def loss(self):
-        """
-        The loss used for training.
-
-        :return: The loss function used for training.
-        :rtype: torch.nn.Module
-        """
-        return self._loss_fn
-
     @property
     def current_condition_name(self):
         """
diff --git a/pina/solver/solver.py b/pina/solver/solver.py
index 99de1df..f6bcc2a 100644
--- a/pina/solver/solver.py
+++ b/pina/solver/solver.py
@@ -5,7 +5,7 @@
 import lightning
 import torch
 from torch._dynamo import OptimizedModule
-from ..problem import AbstractProblem
+from ..problem import AbstractProblem, InverseProblem
 from ..optim import Optimizer, Scheduler, TorchOptimizer, TorchScheduler
 from ..loss import WeightingInterface
 from ..loss.scalar_weighting import _NoWeighting
@@ -64,6 +64,14 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
         self._pina_optimizers = None
         self._pina_schedulers = None

+        # inverse problem handling
+        if isinstance(self.problem, InverseProblem):
+            self._params = self.problem.unknown_parameters
+            self._clamp_params = self._clamp_inverse_problem_params
+        else:
+            self._params = None
+            self._clamp_params = lambda: None
+
     @abstractmethod
     def forward(self, *args, **kwargs):
         """
@@ -231,14 +239,29 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
             containing the condition name and the associated scalar loss.
         :rtype: dict
         """
+        # compute losses
         losses = self.optimization_cycle(batch)
+        # clamp unknown parameters in InverseProblem (if needed)
+        self._clamp_params()
+        # store log
         for name, value in losses.items():
             self.store_log(
                 f"{name}_loss", value.item(), self.get_batch_size(batch)
             )
+        # aggregate
         loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
         return loss

+    def _clamp_inverse_problem_params(self):
+        """
+        Clamps the parameters of the inverse problem solver to specified ranges.
+        """
+        for v in self._params:
+            self._params[v].data.clamp_(
+                self.problem.unknown_parameter_domain.range_[v][0],
+                self.problem.unknown_parameter_domain.range_[v][1],
+            )
+
     @staticmethod
     def _compile_modules(model):
         """
@@ -405,6 +428,15 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
         :rtype: tuple[list[Optimizer], list[Scheduler]]
         """
         self.optimizer.hook(self.model.parameters())
+        if isinstance(self.problem, InverseProblem):
+            self.optimizer.instance.add_param_group(
+                {
+                    "params": [
+                        self._params[var]
+                        for var in self.problem.unknown_variables
+                    ]
+                }
+            )
         self.scheduler.hook(self.optimizer)
         return ([self.optimizer.instance], [self.scheduler.instance])