From 74695434992bfd8e06e8f78da8ac6b0d83f1632c Mon Sep 17 00:00:00 2001 From: Dario Coscia Date: Thu, 4 Sep 2025 14:55:37 +0200 Subject: [PATCH] simplify kwargs logic for equations --- pina/equation/equation.py | 20 ++++++++++++++----- .../physics_informed_solver/pinn_interface.py | 10 +++------- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/pina/equation/equation.py b/pina/equation/equation.py index 60b538e..1e4622d 100644 --- a/pina/equation/equation.py +++ b/pina/equation/equation.py @@ -1,5 +1,7 @@ """Module for the Equation.""" +import inspect + from .equation_interface import EquationInterface @@ -25,6 +27,9 @@ class Equation(EquationInterface): "Expected a callable function, got " f"{equation}" ) + # compute the signature + sig = inspect.signature(equation) + self.__len_sig = len(sig.parameters) self.__equation = equation def residual(self, input_, output_, params_=None): @@ -41,9 +46,14 @@ class Equation(EquationInterface): parameters must be initialized to ``None``. Default is ``None``. :return: The computed residual of the equation. :rtype: LabelTensor + :raises RuntimeError: If the underlying equation signature length is not + 2 (direct problem) or 3 (inverse problem). """ - if params_ is None: - result = self.__equation(input_, output_) - else: - result = self.__equation(input_, output_, params_) - return result + if self.__len_sig == 2: + return self.__equation(input_, output_) + if self.__len_sig == 3: + return self.__equation(input_, output_, params_) + raise RuntimeError( + f"Unexpected number of arguments in equation: {self.__len_sig}. " + "Expected either 2 (direct problem) or 3 (inverse problem)." + ) diff --git a/pina/solver/physics_informed_solver/pinn_interface.py b/pina/solver/physics_informed_solver/pinn_interface.py index 535e7ae..9155e19 100644 --- a/pina/solver/physics_informed_solver/pinn_interface.py +++ b/pina/solver/physics_informed_solver/pinn_interface.py @@ -190,13 +190,9 @@ class PINNInterface(SupervisedSolverInterface, metaclass=ABCMeta): :return: The residual of the solution of the model. :rtype: LabelTensor """ - try: - residual = equation.residual(samples, self.forward(samples)) - except TypeError: - # this occurs when the function has three inputs (inverse problem) - residual = equation.residual( - samples, self.forward(samples), self._params - ) + residual = equation.residual( + samples, self.forward(samples), self._params + ) return residual def _residual_loss(self, samples, equation):