optimizer and scheduler classes

This commit is contained in:
Nicola Demo
2024-06-21 14:37:55 +02:00
parent 2b71e0148d
commit b7d512e8bf
5 changed files with 102 additions and 1 deletions

20
tests/test_optimizer.py Normal file
View File

@@ -0,0 +1,20 @@
import torch
import pytest
from pina import TorchOptimizer
opt_list = [
torch.optim.Adam,
torch.optim.AdamW,
torch.optim.SGD,
torch.optim.RMSprop
]
@pytest.mark.parametrize("optimizer_class", opt_list)
def test_constructor(optimizer_class):
TorchOptimizer(optimizer_class, lr=1e-3)
@pytest.mark.parametrize("optimizer_class", opt_list)
def test_hook(optimizer_class):
opt = TorchOptimizer(optimizer_class, lr=1e-3)
opt.hook(torch.nn.Linear(10, 10).parameters())