fix compile issue (#627)

This commit is contained in:
Dario Coscia
2025-09-08 14:22:25 +02:00
committed by GitHub
parent efc9e327f6
commit 973d0c05ab
2 changed files with 40 additions and 1 deletions

View File

@@ -1,8 +1,10 @@
"""Module for the Physics-Informed Neural Network Interface."""
from abc import ABCMeta, abstractmethod
import warnings
import torch
from ...utils import custom_warning_format
from ..supervised_solver import SupervisedSolverInterface
from ...condition import (
InputTargetCondition,
@@ -10,6 +12,10 @@ from ...condition import (
DomainEquationCondition,
)
# set the warning for torch >= 2.8 compile
warnings.formatwarning = custom_warning_format
warnings.filterwarnings("always", category=UserWarning)
class PINNInterface(SupervisedSolverInterface, metaclass=ABCMeta):
"""
@@ -46,6 +52,36 @@ class PINNInterface(SupervisedSolverInterface, metaclass=ABCMeta):
# current condition name
self.__metric = None
def setup(self, stage):
"""
Setup method executed at the beginning of training and testing.
This method compiles the model only if the installed torch version
is earlier than 2.8, due to known issues with later versions
(see https://github.com/mathLab/PINA/issues/621).
.. warning::
For torch >= 2.8, compilation is disabled. Forcing compilation
on these versions may cause runtime errors or unstable behavior.
:param str stage: The current stage of the training process
(e.g., ``fit``, ``validate``, ``test``, ``predict``).
:return: The result of the parent class ``setup`` method.
:rtype: Any
"""
# Override the compilation, compiling only for torch < 2.8, see
# related issue at https://github.com/mathLab/PINA/issues/621
if torch.__version__ < "2.8":
self.trainer.compile = True
else:
self.trainer.compile = False
warnings.warn(
"Compilation is disabled for torch >= 2.8. "
"Forcing compilation may cause runtime errors or instability.",
UserWarning,
)
return super().setup(stage)
def optimization_cycle(self, batch, loss_residuals=None):
"""
The optimization cycle for the PINN solver.