disable compilation py>=3.14
This commit is contained in:
committed by
Giovanni Canali
parent
24d806b262
commit
64930c431f
@@ -71,9 +71,7 @@ class PINNInterface(SupervisedSolverInterface, metaclass=ABCMeta):
|
|||||||
"""
|
"""
|
||||||
# Override the compilation, compiling only for torch < 2.8, see
|
# Override the compilation, compiling only for torch < 2.8, see
|
||||||
# related issue at https://github.com/mathLab/PINA/issues/621
|
# related issue at https://github.com/mathLab/PINA/issues/621
|
||||||
if torch.__version__ < "2.8":
|
if torch.__version__ >= "2.8":
|
||||||
self.trainer.compile = True
|
|
||||||
else:
|
|
||||||
self.trainer.compile = False
|
self.trainer.compile = False
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Compilation is disabled for torch >= 2.8. "
|
"Compilation is disabled for torch >= 2.8. "
|
||||||
|
|||||||
@@ -174,11 +174,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
|||||||
:return: The result of the parent class ``setup`` method.
|
:return: The result of the parent class ``setup`` method.
|
||||||
:rtype: Any
|
:rtype: Any
|
||||||
"""
|
"""
|
||||||
if stage == "fit" and self.trainer.compile:
|
if self.trainer.compile and not self._is_compiled():
|
||||||
self._setup_compile()
|
|
||||||
if stage == "test" and (
|
|
||||||
self.trainer.compile and not self._is_compiled()
|
|
||||||
):
|
|
||||||
self._setup_compile()
|
self._setup_compile()
|
||||||
return super().setup(stage)
|
return super().setup(stage)
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,17 @@
|
|||||||
"""Module for the Trainer."""
|
"""Module for the Trainer."""
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
import warnings
|
||||||
import torch
|
import torch
|
||||||
import lightning
|
import lightning
|
||||||
from .utils import check_consistency
|
from .utils import check_consistency, custom_warning_format
|
||||||
from .data import PinaDataModule
|
from .data import PinaDataModule
|
||||||
from .solver import SolverInterface, PINNInterface
|
from .solver import SolverInterface, PINNInterface
|
||||||
|
|
||||||
|
# set the warning for compile options
|
||||||
|
warnings.formatwarning = custom_warning_format
|
||||||
|
warnings.filterwarnings("always", category=UserWarning)
|
||||||
|
|
||||||
|
|
||||||
class Trainer(lightning.pytorch.Trainer):
|
class Trainer(lightning.pytorch.Trainer):
|
||||||
"""
|
"""
|
||||||
@@ -49,7 +54,8 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
:param float val_size: The percentage of elements to include in the
|
:param float val_size: The percentage of elements to include in the
|
||||||
validation dataset. Default is ``0.0``.
|
validation dataset. Default is ``0.0``.
|
||||||
:param bool compile: If ``True``, the model is compiled before training.
|
:param bool compile: If ``True``, the model is compiled before training.
|
||||||
Default is ``False``. For Windows users, it is always disabled.
|
Default is ``False``. For Windows users, it is always disabled. Not
|
||||||
|
supported for python version greater or equal than 3.14.
|
||||||
:param bool repeat: Whether to repeat the dataset data in each
|
:param bool repeat: Whether to repeat the dataset data in each
|
||||||
condition during training. For further details, see the
|
condition during training. For further details, see the
|
||||||
:class:`~pina.data.data_module.PinaDataModule` class. Default is
|
:class:`~pina.data.data_module.PinaDataModule` class. Default is
|
||||||
@@ -104,8 +110,17 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
# checking compilation and automatic batching
|
# checking compilation and automatic batching
|
||||||
if compile is None or sys.platform == "win32":
|
# compilation disabled for Windows and for Python 3.14+
|
||||||
|
if (
|
||||||
|
compile is None
|
||||||
|
or sys.platform == "win32"
|
||||||
|
or sys.version_info >= (3, 14)
|
||||||
|
):
|
||||||
compile = False
|
compile = False
|
||||||
|
warnings.warn(
|
||||||
|
"Compilation is disabled for Python 3.14+ and for Windows.",
|
||||||
|
UserWarning,
|
||||||
|
)
|
||||||
|
|
||||||
repeat = repeat if repeat is not None else False
|
repeat = repeat if repeat is not None else False
|
||||||
|
|
||||||
@@ -325,3 +340,23 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
if batch_size is not None:
|
if batch_size is not None:
|
||||||
check_consistency(batch_size, int)
|
check_consistency(batch_size, int)
|
||||||
return pin_memory, num_workers, shuffle, batch_size
|
return pin_memory, num_workers, shuffle, batch_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def compile(self):
|
||||||
|
"""
|
||||||
|
Whether compilation is required or not.
|
||||||
|
|
||||||
|
:return: ``True`` if compilation is required, ``False`` otherwise.
|
||||||
|
:rtype: bool
|
||||||
|
"""
|
||||||
|
return self._compile
|
||||||
|
|
||||||
|
@compile.setter
|
||||||
|
def compile(self, value):
|
||||||
|
"""
|
||||||
|
Setting the value of compile.
|
||||||
|
|
||||||
|
:param bool value: Whether compilation is required or not.
|
||||||
|
"""
|
||||||
|
check_consistency(value, bool)
|
||||||
|
self._compile = value
|
||||||
|
|||||||
Reference in New Issue
Block a user