disable compilation py>=3.14
This commit is contained in:
committed by
Giovanni Canali
parent
24d806b262
commit
64930c431f
@@ -71,9 +71,7 @@ class PINNInterface(SupervisedSolverInterface, metaclass=ABCMeta):
|
||||
"""
|
||||
# Override the compilation, compiling only for torch < 2.8, see
|
||||
# related issue at https://github.com/mathLab/PINA/issues/621
|
||||
if torch.__version__ < "2.8":
|
||||
self.trainer.compile = True
|
||||
else:
|
||||
if torch.__version__ >= "2.8":
|
||||
self.trainer.compile = False
|
||||
warnings.warn(
|
||||
"Compilation is disabled for torch >= 2.8. "
|
||||
|
||||
@@ -174,11 +174,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
:return: The result of the parent class ``setup`` method.
|
||||
:rtype: Any
|
||||
"""
|
||||
if stage == "fit" and self.trainer.compile:
|
||||
self._setup_compile()
|
||||
if stage == "test" and (
|
||||
self.trainer.compile and not self._is_compiled()
|
||||
):
|
||||
if self.trainer.compile and not self._is_compiled():
|
||||
self._setup_compile()
|
||||
return super().setup(stage)
|
||||
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
"""Module for the Trainer."""
|
||||
|
||||
import sys
|
||||
import warnings
|
||||
import torch
|
||||
import lightning
|
||||
from .utils import check_consistency
|
||||
from .utils import check_consistency, custom_warning_format
|
||||
from .data import PinaDataModule
|
||||
from .solver import SolverInterface, PINNInterface
|
||||
|
||||
# set the warning for compile options
|
||||
warnings.formatwarning = custom_warning_format
|
||||
warnings.filterwarnings("always", category=UserWarning)
|
||||
|
||||
|
||||
class Trainer(lightning.pytorch.Trainer):
|
||||
"""
|
||||
@@ -49,7 +54,8 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
:param float val_size: The percentage of elements to include in the
|
||||
validation dataset. Default is ``0.0``.
|
||||
:param bool compile: If ``True``, the model is compiled before training.
|
||||
Default is ``False``. For Windows users, it is always disabled.
|
||||
Default is ``False``. For Windows users, it is always disabled. Not
|
||||
supported for python version greater or equal than 3.14.
|
||||
:param bool repeat: Whether to repeat the dataset data in each
|
||||
condition during training. For further details, see the
|
||||
:class:`~pina.data.data_module.PinaDataModule` class. Default is
|
||||
@@ -104,8 +110,17 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# checking compilation and automatic batching
|
||||
if compile is None or sys.platform == "win32":
|
||||
# compilation disabled for Windows and for Python 3.14+
|
||||
if (
|
||||
compile is None
|
||||
or sys.platform == "win32"
|
||||
or sys.version_info >= (3, 14)
|
||||
):
|
||||
compile = False
|
||||
warnings.warn(
|
||||
"Compilation is disabled for Python 3.14+ and for Windows.",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
repeat = repeat if repeat is not None else False
|
||||
|
||||
@@ -325,3 +340,23 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
if batch_size is not None:
|
||||
check_consistency(batch_size, int)
|
||||
return pin_memory, num_workers, shuffle, batch_size
|
||||
|
||||
@property
|
||||
def compile(self):
|
||||
"""
|
||||
Whether compilation is required or not.
|
||||
|
||||
:return: ``True`` if compilation is required, ``False`` otherwise.
|
||||
:rtype: bool
|
||||
"""
|
||||
return self._compile
|
||||
|
||||
@compile.setter
|
||||
def compile(self, value):
|
||||
"""
|
||||
Setting the value of compile.
|
||||
|
||||
:param bool value: Whether compilation is required or not.
|
||||
"""
|
||||
check_consistency(value, bool)
|
||||
self._compile = value
|
||||
|
||||
Reference in New Issue
Block a user