From 973d0c05ab23c7680288c758899ad606a6d82d30 Mon Sep 17 00:00:00 2001 From: Dario Coscia <93731561+dario-coscia@users.noreply.github.com> Date: Mon, 8 Sep 2025 14:22:25 +0200 Subject: [PATCH] fix compile issue (#627) --- .../physics_informed_solver/pinn_interface.py | 36 +++++++++++++++++++ pina/solver/solver.py | 5 ++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/pina/solver/physics_informed_solver/pinn_interface.py b/pina/solver/physics_informed_solver/pinn_interface.py index 976f6ce..535e7ae 100644 --- a/pina/solver/physics_informed_solver/pinn_interface.py +++ b/pina/solver/physics_informed_solver/pinn_interface.py @@ -1,8 +1,10 @@ """Module for the Physics-Informed Neural Network Interface.""" from abc import ABCMeta, abstractmethod +import warnings import torch +from ...utils import custom_warning_format from ..supervised_solver import SupervisedSolverInterface from ...condition import ( InputTargetCondition, @@ -10,6 +12,10 @@ from ...condition import ( DomainEquationCondition, ) +# set the warning for torch >= 2.8 compile +warnings.formatwarning = custom_warning_format +warnings.filterwarnings("always", category=UserWarning) + class PINNInterface(SupervisedSolverInterface, metaclass=ABCMeta): """ @@ -46,6 +52,36 @@ class PINNInterface(SupervisedSolverInterface, metaclass=ABCMeta): # current condition name self.__metric = None + def setup(self, stage): + """ + Setup method executed at the beginning of training and testing. + + This method compiles the model only if the installed torch version + is earlier than 2.8, due to known issues with later versions + (see https://github.com/mathLab/PINA/issues/621). + + .. warning:: + For torch >= 2.8, compilation is disabled. Forcing compilation + on these versions may cause runtime errors or unstable behavior. + + :param str stage: The current stage of the training process + (e.g., ``fit``, ``validate``, ``test``, ``predict``). + :return: The result of the parent class ``setup`` method. + :rtype: Any + """ + # Override the compilation, compiling only for torch < 2.8, see + # related issue at https://github.com/mathLab/PINA/issues/621 + if torch.__version__ < "2.8": + self.trainer.compile = True + else: + self.trainer.compile = False + warnings.warn( + "Compilation is disabled for torch >= 2.8. " + "Forcing compilation may cause runtime errors or instability.", + UserWarning, + ) + return super().setup(stage) + def optimization_cycle(self, batch, loss_residuals=None): """ The optimization cycle for the PINN solver. diff --git a/pina/solver/solver.py b/pina/solver/solver.py index f6bcc2a..f3ff405 100644 --- a/pina/solver/solver.py +++ b/pina/solver/solver.py @@ -169,7 +169,10 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): compile the model if the :class:`~pina.trainer.Trainer` ``compile`` is ``True``. - + :param str stage: The current stage of the training process + (e.g., ``fit``, ``validate``, ``test``, ``predict``). + :return: The result of the parent class ``setup`` method. + :rtype: Any """ if stage == "fit" and self.trainer.compile: self._setup_compile()