diff --git a/pina/optimizer.py b/pina/optimizer.py
deleted file mode 100644
index 08631e6..0000000
--- a/pina/optimizer.py
+++ /dev/null
@@ -1,21 +0,0 @@
-""" Module for PINA Optimizer """
-
-import torch
-from .utils import check_consistency
-
-class Optimizer: # TODO improve interface
-    pass
-
-
-class TorchOptimizer(Optimizer):
-
-    def __init__(self, optimizer_class, **kwargs):
-        check_consistency(optimizer_class, torch.optim.Optimizer, subclass=True)
-
-        self.optimizer_class = optimizer_class
-        self.kwargs = kwargs
-
-    def hook(self, parameters):
-        self.optimizer_instance = self.optimizer_class(
-            parameters, **self.kwargs
-        )
\ No newline at end of file
diff --git a/pina/scheduler.py b/pina/scheduler.py
deleted file mode 100644
index 563f829..0000000
--- a/pina/scheduler.py
+++ /dev/null
@@ -1,29 +0,0 @@
-""" Module for PINA Scheduler """
-
-try:
-    from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0
-except ImportError:
-    from torch.optim.lr_scheduler import (
-        _LRScheduler as LRScheduler,
-    ) # torch < 2.0
-from .optimizer import Optimizer
-from .utils import check_consistency
-
-
-class Scheduler: # TODO improve interface
-    pass
-
-
-class TorchScheduler(Scheduler):
-
-    def __init__(self, scheduler_class, **kwargs):
-        check_consistency(scheduler_class, LRScheduler, subclass=True)
-
-        self.scheduler_class = scheduler_class
-        self.kwargs = kwargs
-
-    def hook(self, optimizer):
-        check_consistency(optimizer, Optimizer)
-        self.scheduler_instance = self.scheduler_class(
-            optimizer.optimizer_instance, **self.kwargs
-        )
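
For context, the two deleted modules implemented a deferred-construction ("hook") pattern: the wrappers only store an optimizer/scheduler class together with its keyword arguments, and the actual `torch` objects are built later when `hook()` receives concrete parameters. The sketch below reproduces that pattern in isolation so the removed behaviour stays documented. It is a minimal sketch, not PINA's API: the `torch.nn.Linear` model, the Adam/StepLR choices, and the replacement of `check_consistency` with plain `issubclass` checks are illustrative assumptions.

```python
# Minimal sketch of the deferred-construction pattern implemented by the
# removed pina/optimizer.py and pina/scheduler.py. The model, learning rate,
# and scheduler choice are illustrative; PINA's check_consistency validation
# is replaced here by plain issubclass checks.
import torch

try:
    from torch.optim.lr_scheduler import LRScheduler  # torch >= 2.0
except ImportError:
    from torch.optim.lr_scheduler import _LRScheduler as LRScheduler  # torch < 2.0


class TorchOptimizer:
    """Stores an optimizer class and its kwargs; the torch object is built in hook()."""

    def __init__(self, optimizer_class, **kwargs):
        assert issubclass(optimizer_class, torch.optim.Optimizer)
        self.optimizer_class = optimizer_class
        self.kwargs = kwargs

    def hook(self, parameters):
        # Bind the optimizer to concrete model parameters.
        self.optimizer_instance = self.optimizer_class(parameters, **self.kwargs)


class TorchScheduler:
    """Stores a scheduler class and its kwargs; the torch object is built in hook()."""

    def __init__(self, scheduler_class, **kwargs):
        assert issubclass(scheduler_class, LRScheduler)
        self.scheduler_class = scheduler_class
        self.kwargs = kwargs

    def hook(self, optimizer):
        # Bind the scheduler to an already-hooked TorchOptimizer.
        self.scheduler_instance = self.scheduler_class(
            optimizer.optimizer_instance, **self.kwargs
        )


if __name__ == "__main__":
    model = torch.nn.Linear(2, 1)
    optimizer = TorchOptimizer(torch.optim.Adam, lr=1e-3)
    scheduler = TorchScheduler(torch.optim.lr_scheduler.StepLR, step_size=10)

    optimizer.hook(model.parameters())  # instantiate torch.optim.Adam here
    scheduler.hook(optimizer)           # instantiate StepLR on top of it
```

Because construction only records the class and kwargs, optimizers and schedulers can be declared before a model's parameters exist; `hook()` performs the actual instantiation later, presumably from a solver or trainer.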