optimizer and scheduler classes

This commit is contained in:
Nicola Demo
2024-06-21 14:37:55 +02:00
parent 2b71e0148d
commit b7d512e8bf
5 changed files with 102 additions and 1 deletions

View File

@@ -6,13 +6,17 @@ __all__ = [
"Condition", "Condition",
"SamplePointDataset", "SamplePointDataset",
"SamplePointLoader", "SamplePointLoader",
"TorchOptimizer",
"TorchScheduler",
] ]
from .meta import * from .meta import *
#from .label_tensor import LabelTensor from .label_tensor import LabelTensor
from .solvers.solver import SolverInterface from .solvers.solver import SolverInterface
from .trainer import Trainer from .trainer import Trainer
from .plotter import Plotter from .plotter import Plotter
from .condition import Condition from .condition import Condition
from .dataset import SamplePointDataset from .dataset import SamplePointDataset
from .dataset import SamplePointLoader from .dataset import SamplePointLoader
from .optimizer import TorchOptimizer
from .scheduler import TorchScheduler

21
pina/optimizer.py Normal file
View File

@@ -0,0 +1,21 @@
""" Module for PINA Optimizer """
import torch
from .utils import check_consistency
class Optimizer: # TODO improve interface
pass
class TorchOptimizer(Optimizer):
def __init__(self, optimizer_class, **kwargs):
check_consistency(optimizer_class, torch.optim.Optimizer, subclass=True)
self.optimizer_class = optimizer_class
self.kwargs = kwargs
def hook(self, parameters):
self.optimizer_instance = self.optimizer_class(
parameters, **self.kwargs
)

29
pina/scheduler.py Normal file
View File

@@ -0,0 +1,29 @@
""" Module for PINA Scheduler """
try:
from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0
except ImportError:
from torch.optim.lr_scheduler import (
_LRScheduler as LRScheduler,
) # torch < 2.0
from .optimizer import Optimizer
from .utils import check_consistency
class Scheduler: # TODO improve interface
pass
class TorchScheduler(Scheduler):
def __init__(self, scheduler_class, **kwargs):
check_consistency(scheduler_class, LRScheduler, subclass=True)
self.scheduler_class = scheduler_class
self.kwargs = kwargs
def hook(self, optimizer):
check_consistency(optimizer, Optimizer)
self.scheduler_instance = self.scheduler_class(
optimizer.optimizer_instance, **self.kwargs
)

20
tests/test_optimizer.py Normal file
View File

@@ -0,0 +1,20 @@
import torch
import pytest
from pina import TorchOptimizer
opt_list = [
torch.optim.Adam,
torch.optim.AdamW,
torch.optim.SGD,
torch.optim.RMSprop
]
@pytest.mark.parametrize("optimizer_class", opt_list)
def test_constructor(optimizer_class):
TorchOptimizer(optimizer_class, lr=1e-3)
@pytest.mark.parametrize("optimizer_class", opt_list)
def test_hook(optimizer_class):
opt = TorchOptimizer(optimizer_class, lr=1e-3)
opt.hook(torch.nn.Linear(10, 10).parameters())

27
tests/test_scheduler.py Normal file
View File

@@ -0,0 +1,27 @@
import torch
import pytest
from pina import TorchOptimizer, TorchScheduler
opt_list = [
torch.optim.Adam,
torch.optim.AdamW,
torch.optim.SGD,
torch.optim.RMSprop
]
sch_list = [
torch.optim.lr_scheduler.ConstantLR
]
@pytest.mark.parametrize("scheduler_class", sch_list)
def test_constructor(scheduler_class):
TorchScheduler(scheduler_class)
@pytest.mark.parametrize("optimizer_class", opt_list)
@pytest.mark.parametrize("scheduler_class", sch_list)
def test_hook(optimizer_class, scheduler_class):
opt = TorchOptimizer(optimizer_class, lr=1e-3)
opt.hook(torch.nn.Linear(10, 10).parameters())
sch = TorchScheduler(scheduler_class)
sch.hook(opt)