fix compile issue (#627)
This commit is contained in:
@@ -1,8 +1,10 @@
|
|||||||
"""Module for the Physics-Informed Neural Network Interface."""
|
"""Module for the Physics-Informed Neural Network Interface."""
|
||||||
|
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
|
import warnings
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from ...utils import custom_warning_format
|
||||||
from ..supervised_solver import SupervisedSolverInterface
|
from ..supervised_solver import SupervisedSolverInterface
|
||||||
from ...condition import (
|
from ...condition import (
|
||||||
InputTargetCondition,
|
InputTargetCondition,
|
||||||
@@ -10,6 +12,10 @@ from ...condition import (
|
|||||||
DomainEquationCondition,
|
DomainEquationCondition,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# set the warning for torch >= 2.8 compile
|
||||||
|
warnings.formatwarning = custom_warning_format
|
||||||
|
warnings.filterwarnings("always", category=UserWarning)
|
||||||
|
|
||||||
|
|
||||||
class PINNInterface(SupervisedSolverInterface, metaclass=ABCMeta):
|
class PINNInterface(SupervisedSolverInterface, metaclass=ABCMeta):
|
||||||
"""
|
"""
|
||||||
@@ -46,6 +52,36 @@ class PINNInterface(SupervisedSolverInterface, metaclass=ABCMeta):
|
|||||||
# current condition name
|
# current condition name
|
||||||
self.__metric = None
|
self.__metric = None
|
||||||
|
|
||||||
|
def setup(self, stage):
|
||||||
|
"""
|
||||||
|
Setup method executed at the beginning of training and testing.
|
||||||
|
|
||||||
|
This method compiles the model only if the installed torch version
|
||||||
|
is earlier than 2.8, due to known issues with later versions
|
||||||
|
(see https://github.com/mathLab/PINA/issues/621).
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
For torch >= 2.8, compilation is disabled. Forcing compilation
|
||||||
|
on these versions may cause runtime errors or unstable behavior.
|
||||||
|
|
||||||
|
:param str stage: The current stage of the training process
|
||||||
|
(e.g., ``fit``, ``validate``, ``test``, ``predict``).
|
||||||
|
:return: The result of the parent class ``setup`` method.
|
||||||
|
:rtype: Any
|
||||||
|
"""
|
||||||
|
# Override the compilation, compiling only for torch < 2.8, see
|
||||||
|
# related issue at https://github.com/mathLab/PINA/issues/621
|
||||||
|
if torch.__version__ < "2.8":
|
||||||
|
self.trainer.compile = True
|
||||||
|
else:
|
||||||
|
self.trainer.compile = False
|
||||||
|
warnings.warn(
|
||||||
|
"Compilation is disabled for torch >= 2.8. "
|
||||||
|
"Forcing compilation may cause runtime errors or instability.",
|
||||||
|
UserWarning,
|
||||||
|
)
|
||||||
|
return super().setup(stage)
|
||||||
|
|
||||||
def optimization_cycle(self, batch, loss_residuals=None):
|
def optimization_cycle(self, batch, loss_residuals=None):
|
||||||
"""
|
"""
|
||||||
The optimization cycle for the PINN solver.
|
The optimization cycle for the PINN solver.
|
||||||
|
|||||||
@@ -169,7 +169,10 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
|||||||
compile the model if the :class:`~pina.trainer.Trainer`
|
compile the model if the :class:`~pina.trainer.Trainer`
|
||||||
``compile`` is ``True``.
|
``compile`` is ``True``.
|
||||||
|
|
||||||
|
:param str stage: The current stage of the training process
|
||||||
|
(e.g., ``fit``, ``validate``, ``test``, ``predict``).
|
||||||
|
:return: The result of the parent class ``setup`` method.
|
||||||
|
:rtype: Any
|
||||||
"""
|
"""
|
||||||
if stage == "fit" and self.trainer.compile:
|
if stage == "fit" and self.trainer.compile:
|
||||||
self._setup_compile()
|
self._setup_compile()
|
||||||
|
|||||||
Reference in New Issue
Block a user