fix compile issue (#627)
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
"""Module for the Physics-Informed Neural Network Interface."""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
import warnings
|
||||
import torch
|
||||
|
||||
from ...utils import custom_warning_format
|
||||
from ..supervised_solver import SupervisedSolverInterface
|
||||
from ...condition import (
|
||||
InputTargetCondition,
|
||||
@@ -10,6 +12,10 @@ from ...condition import (
|
||||
DomainEquationCondition,
|
||||
)
|
||||
|
||||
# set the warning for torch >= 2.8 compile
|
||||
warnings.formatwarning = custom_warning_format
|
||||
warnings.filterwarnings("always", category=UserWarning)
|
||||
|
||||
|
||||
class PINNInterface(SupervisedSolverInterface, metaclass=ABCMeta):
|
||||
"""
|
||||
@@ -46,6 +52,36 @@ class PINNInterface(SupervisedSolverInterface, metaclass=ABCMeta):
|
||||
# current condition name
|
||||
self.__metric = None
|
||||
|
||||
def setup(self, stage):
|
||||
"""
|
||||
Setup method executed at the beginning of training and testing.
|
||||
|
||||
This method compiles the model only if the installed torch version
|
||||
is earlier than 2.8, due to known issues with later versions
|
||||
(see https://github.com/mathLab/PINA/issues/621).
|
||||
|
||||
.. warning::
|
||||
For torch >= 2.8, compilation is disabled. Forcing compilation
|
||||
on these versions may cause runtime errors or unstable behavior.
|
||||
|
||||
:param str stage: The current stage of the training process
|
||||
(e.g., ``fit``, ``validate``, ``test``, ``predict``).
|
||||
:return: The result of the parent class ``setup`` method.
|
||||
:rtype: Any
|
||||
"""
|
||||
# Override the compilation, compiling only for torch < 2.8, see
|
||||
# related issue at https://github.com/mathLab/PINA/issues/621
|
||||
if torch.__version__ < "2.8":
|
||||
self.trainer.compile = True
|
||||
else:
|
||||
self.trainer.compile = False
|
||||
warnings.warn(
|
||||
"Compilation is disabled for torch >= 2.8. "
|
||||
"Forcing compilation may cause runtime errors or instability.",
|
||||
UserWarning,
|
||||
)
|
||||
return super().setup(stage)
|
||||
|
||||
def optimization_cycle(self, batch, loss_residuals=None):
|
||||
"""
|
||||
The optimization cycle for the PINN solver.
|
||||
|
||||
@@ -169,7 +169,10 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
compile the model if the :class:`~pina.trainer.Trainer`
|
||||
``compile`` is ``True``.
|
||||
|
||||
|
||||
:param str stage: The current stage of the training process
|
||||
(e.g., ``fit``, ``validate``, ``test``, ``predict``).
|
||||
:return: The result of the parent class ``setup`` method.
|
||||
:rtype: Any
|
||||
"""
|
||||
if stage == "fit" and self.trainer.compile:
|
||||
self._setup_compile()
|
||||
|
||||
Reference in New Issue
Block a user