From b7d512e8bf776e05074cc7db0f3d0b97e9473bd7 Mon Sep 17 00:00:00 2001 From: Nicola Demo Date: Fri, 21 Jun 2024 14:37:55 +0200 Subject: [PATCH] optimizer and scheduler classes --- pina/__init__.py | 6 +++++- pina/optimizer.py | 21 +++++++++++++++++++++ pina/scheduler.py | 29 +++++++++++++++++++++++++++++ tests/test_optimizer.py | 20 ++++++++++++++++++++ tests/test_scheduler.py | 27 +++++++++++++++++++++++++++ 5 files changed, 102 insertions(+), 1 deletion(-) create mode 100644 pina/optimizer.py create mode 100644 pina/scheduler.py create mode 100644 tests/test_optimizer.py create mode 100644 tests/test_scheduler.py diff --git a/pina/__init__.py b/pina/__init__.py index c63440b..7c72533 100644 --- a/pina/__init__.py +++ b/pina/__init__.py @@ -6,13 +6,17 @@ __all__ = [ "Condition", "SamplePointDataset", "SamplePointLoader", + "TorchOptimizer", + "TorchScheduler", ] from .meta import * -#from .label_tensor import LabelTensor +from .label_tensor import LabelTensor from .solvers.solver import SolverInterface from .trainer import Trainer from .plotter import Plotter from .condition import Condition from .dataset import SamplePointDataset from .dataset import SamplePointLoader +from .optimizer import TorchOptimizer +from .scheduler import TorchScheduler diff --git a/pina/optimizer.py b/pina/optimizer.py new file mode 100644 index 0000000..d400e82 --- /dev/null +++ b/pina/optimizer.py @@ -0,0 +1,21 @@ +""" Module for PINA Optimizer """ + +import torch +from .utils import check_consistency + +class Optimizer: # TODO improve interface + pass + + +class TorchOptimizer(Optimizer): + + def __init__(self, optimizer_class, **kwargs): + check_consistency(optimizer_class, torch.optim.Optimizer, subclass=True) + + self.optimizer_class = optimizer_class + self.kwargs = kwargs + + def hook(self, parameters): + self.optimizer_instance = self.optimizer_class( + parameters, **self.kwargs + ) diff --git a/pina/scheduler.py b/pina/scheduler.py new file mode 
100644 index 0000000..563f829 --- /dev/null +++ b/pina/scheduler.py @@ -0,0 +1,29 @@ +""" Module for PINA Scheduler """ + +try: + from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0 +except ImportError: + from torch.optim.lr_scheduler import ( + _LRScheduler as LRScheduler, + ) # torch < 2.0 +from .optimizer import Optimizer +from .utils import check_consistency + + +class Scheduler: # TODO improve interface + pass + + +class TorchScheduler(Scheduler): + + def __init__(self, scheduler_class, **kwargs): + check_consistency(scheduler_class, LRScheduler, subclass=True) + + self.scheduler_class = scheduler_class + self.kwargs = kwargs + + def hook(self, optimizer): + check_consistency(optimizer, Optimizer) + self.scheduler_instance = self.scheduler_class( + optimizer.optimizer_instance, **self.kwargs + ) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py new file mode 100644 index 0000000..489bbdc --- /dev/null +++ b/tests/test_optimizer.py @@ -0,0 +1,20 @@ + +import torch +import pytest +from pina import TorchOptimizer + +opt_list = [ + torch.optim.Adam, + torch.optim.AdamW, + torch.optim.SGD, + torch.optim.RMSprop +] + +@pytest.mark.parametrize("optimizer_class", opt_list) +def test_constructor(optimizer_class): + TorchOptimizer(optimizer_class, lr=1e-3) + +@pytest.mark.parametrize("optimizer_class", opt_list) +def test_hook(optimizer_class): + opt = TorchOptimizer(optimizer_class, lr=1e-3) + opt.hook(torch.nn.Linear(10, 10).parameters()) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py new file mode 100644 index 0000000..4cde13e --- /dev/null +++ b/tests/test_scheduler.py @@ -0,0 +1,27 @@ + +import torch +import pytest +from pina import TorchOptimizer, TorchScheduler + +opt_list = [ + torch.optim.Adam, + torch.optim.AdamW, + torch.optim.SGD, + torch.optim.RMSprop +] + +sch_list = [ + torch.optim.lr_scheduler.ConstantLR +] + +@pytest.mark.parametrize("scheduler_class", sch_list) +def 
test_constructor(scheduler_class): + TorchScheduler(scheduler_class) + +@pytest.mark.parametrize("optimizer_class", opt_list) +@pytest.mark.parametrize("scheduler_class", sch_list) +def test_hook(optimizer_class, scheduler_class): + opt = TorchOptimizer(optimizer_class, lr=1e-3) + opt.hook(torch.nn.Linear(10, 10).parameters()) + sch = TorchScheduler(scheduler_class) + sch.hook(opt)