Files
PINA/pina/optim/torch_scheduler.py
Dario Coscia 9cae9a438f Update solvers (#434)
* Enable DDP training with batch_size=None and add validity check for split sizes
* Refactoring SolverInterfaces (#435)
* Solver update + weighting
* Updating PINN for 0.2
* Modify GAROM + tests
* Adding more versatile loggers
* Disable compilation when running on Windows
* Fix tests

---------

Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: FilippoOlivo <filippo@filippoolivo.com>
2025-03-19 17:46:35 +01:00

35 lines
960 B
Python

"""Module for the PINA wrapper around PyTorch learning-rate schedulers."""
import torch
try:
from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0
except ImportError:
from torch.optim.lr_scheduler import (
_LRScheduler as LRScheduler, ) # torch < 2.0
from ..utils import check_consistency
from .optimizer_interface import Optimizer
from .scheduler_interface import Scheduler
class TorchScheduler(Scheduler):
    """
    Wrapper exposing a PyTorch learning-rate scheduler class through the
    PINA :class:`Scheduler` interface.

    Only the scheduler class and its keyword arguments are stored at
    construction time; the concrete scheduler object is created later by
    :meth:`hook`, once an optimizer wrapper is available.
    """

    def __init__(self, scheduler_class, **kwargs):
        """
        Store the scheduler class and the arguments used to build it.

        :param scheduler_class: The scheduler class to wrap. It is checked
            with ``check_consistency`` to be a subclass of ``LRScheduler``
            (``_LRScheduler`` on torch < 2.0).
        :param kwargs: Extra keyword arguments forwarded unchanged to
            ``scheduler_class`` when :meth:`hook` builds the scheduler.
        """
        # Validate the class before any attribute is set.
        check_consistency(scheduler_class, LRScheduler, subclass=True)

        self.scheduler_class, self.kwargs = scheduler_class, kwargs
        # No scheduler exists until ``hook`` is called.
        self._scheduler_instance = None

    def hook(self, optimizer):
        """
        Build the wrapped scheduler around the given optimizer wrapper.

        :param Optimizer optimizer: A PINA optimizer wrapper. Its
            ``instance`` attribute is passed as the first positional
            argument of ``scheduler_class`` (presumably the underlying
            ``torch`` optimizer -- confirm against ``Optimizer``).
        """
        check_consistency(optimizer, Optimizer)
        inner_optimizer = optimizer.instance
        # Calling ``hook`` again replaces any previously built scheduler.
        self._scheduler_instance = self.scheduler_class(
            inner_optimizer, **self.kwargs
        )

    @property
    def instance(self):
        """
        The scheduler built by :meth:`hook`, or ``None`` if :meth:`hook`
        has not been called yet.
        """
        return self._scheduler_instance