disable compilation py>=3.14

This commit is contained in:
Dario Coscia
2025-10-28 12:29:19 +01:00
committed by Giovanni Canali
parent 24d806b262
commit 64930c431f
3 changed files with 40 additions and 11 deletions

View File

@@ -71,9 +71,7 @@ class PINNInterface(SupervisedSolverInterface, metaclass=ABCMeta):
""" """
# Override the compilation, compiling only for torch < 2.8, see # Override the compilation, compiling only for torch < 2.8, see
# related issue at https://github.com/mathLab/PINA/issues/621 # related issue at https://github.com/mathLab/PINA/issues/621
if torch.__version__ < "2.8": if torch.__version__ >= "2.8":
self.trainer.compile = True
else:
self.trainer.compile = False self.trainer.compile = False
warnings.warn( warnings.warn(
"Compilation is disabled for torch >= 2.8. " "Compilation is disabled for torch >= 2.8. "

View File

@@ -174,11 +174,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
:return: The result of the parent class ``setup`` method. :return: The result of the parent class ``setup`` method.
:rtype: Any :rtype: Any
""" """
if stage == "fit" and self.trainer.compile: if self.trainer.compile and not self._is_compiled():
self._setup_compile()
if stage == "test" and (
self.trainer.compile and not self._is_compiled()
):
self._setup_compile() self._setup_compile()
return super().setup(stage) return super().setup(stage)

View File

@@ -1,12 +1,17 @@
"""Module for the Trainer.""" """Module for the Trainer."""
import sys import sys
import warnings
import torch import torch
import lightning import lightning
from .utils import check_consistency from .utils import check_consistency, custom_warning_format
from .data import PinaDataModule from .data import PinaDataModule
from .solver import SolverInterface, PINNInterface from .solver import SolverInterface, PINNInterface
# set the warning for compile options
warnings.formatwarning = custom_warning_format
warnings.filterwarnings("always", category=UserWarning)
class Trainer(lightning.pytorch.Trainer): class Trainer(lightning.pytorch.Trainer):
""" """
@@ -49,7 +54,8 @@ class Trainer(lightning.pytorch.Trainer):
:param float val_size: The percentage of elements to include in the :param float val_size: The percentage of elements to include in the
validation dataset. Default is ``0.0``. validation dataset. Default is ``0.0``.
:param bool compile: If ``True``, the model is compiled before training. :param bool compile: If ``True``, the model is compiled before training.
Default is ``False``. For Windows users, it is always disabled. Default is ``False``. For Windows users, it is always disabled. Not
supported for python version greater or equal than 3.14.
:param bool repeat: Whether to repeat the dataset data in each :param bool repeat: Whether to repeat the dataset data in each
condition during training. For further details, see the condition during training. For further details, see the
:class:`~pina.data.data_module.PinaDataModule` class. Default is :class:`~pina.data.data_module.PinaDataModule` class. Default is
@@ -104,8 +110,17 @@ class Trainer(lightning.pytorch.Trainer):
super().__init__(**kwargs) super().__init__(**kwargs)
# checking compilation and automatic batching # checking compilation and automatic batching
if compile is None or sys.platform == "win32": # compilation disabled for Windows and for Python 3.14+
if (
compile is None
or sys.platform == "win32"
or sys.version_info >= (3, 14)
):
compile = False compile = False
warnings.warn(
"Compilation is disabled for Python 3.14+ and for Windows.",
UserWarning,
)
repeat = repeat if repeat is not None else False repeat = repeat if repeat is not None else False
@@ -325,3 +340,23 @@ class Trainer(lightning.pytorch.Trainer):
if batch_size is not None: if batch_size is not None:
check_consistency(batch_size, int) check_consistency(batch_size, int)
return pin_memory, num_workers, shuffle, batch_size return pin_memory, num_workers, shuffle, batch_size
@property
def compile(self):
"""
Whether compilation is required or not.
:return: ``True`` if compilation is required, ``False`` otherwise.
:rtype: bool
"""
return self._compile
@compile.setter
def compile(self, value):
"""
Setting the value of compile.
:param bool value: Whether compilation is required or not.
"""
check_consistency(value, bool)
self._compile = value