refactor
@@ -1,21 +1,14 @@
 """ Module for SupervisedSolver """

 import torch
+from torch.nn.modules.loss import _Loss
-try:
-    from torch.optim.lr_scheduler import LRScheduler  # torch >= 2.0
-except ImportError:
-    from torch.optim.lr_scheduler import (
-        _LRScheduler as LRScheduler,
-    )  # torch < 2.0
-from torch.optim.lr_scheduler import ConstantLR

+from ..optim import Optimizer, Scheduler, TorchOptimizer, TorchScheduler
 from .solver import SolverInterface
 from ..label_tensor import LabelTensor
 from ..utils import check_consistency
 from ..loss import LossInterface
-from torch.nn.modules.loss import _Loss


 class SupervisedSolver(SolverInterface):
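The version-guarded `LRScheduler` import and the raw `ConstantLR` import give way to wrapper types from the package's own `optim` module. Judging only from the call sites later in this diff (and assuming the package root is `pina`, per the `_pina_*` attribute names used below), the wrappers are built like this:

```python
import torch

# assumed absolute import path; the diff itself uses the relative
# form `from ..optim import ...`
from pina.optim import TorchOptimizer, TorchScheduler

# constructor arguments mirror the defaults applied in __init__ below
optimizer = TorchOptimizer(torch.optim.Adam, lr=0.001)
scheduler = TorchScheduler(torch.optim.lr_scheduler.ConstantLR)
```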
@@ -51,12 +44,9 @@ class SupervisedSolver(SolverInterface):
         self,
         problem,
         model,
         extra_features=None,
-        loss=torch.nn.MSELoss(),
-        optimizer=torch.optim.Adam,
-        optimizer_kwargs={"lr": 0.001},
-        scheduler=ConstantLR,
-        scheduler_kwargs={"factor": 1, "total_iters": 0},
+        loss=None,
+        optimizer=None,
+        scheduler=None,
     ):
         """
         :param AbstractProblem problem: The formulation of the problem.
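Replacing the instantiated defaults (`torch.nn.MSELoss()` and the keyword-argument dicts) with `None` sentinels follows the standard Python idiom for avoiding defaults that are evaluated once at definition time; presumably that, together with the new wrapper API, motivates the signature change. A generic illustration of the pitfall, unrelated to PINA itself:

```python
# a default built at def-time is one shared object across all calls
def bad(history=[]):
    history.append(1)
    return history

print(bad())  # [1]
print(bad())  # [1, 1]  <- state leaks between calls

def good(history=None):
    if history is None:   # sentinel: build a fresh object per call
        history = []
    history.append(1)
    return history
```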
@@ -73,24 +63,26 @@ class SupervisedSolver(SolverInterface):
             rate scheduler.
-        :param dict scheduler_kwargs: LR scheduler constructor keyword args.
         """
+        if loss is None:
+            loss = torch.nn.MSELoss()
+
+        if optimizer is None:
+            optimizer = TorchOptimizer(torch.optim.Adam, lr=0.001)
+
+        if scheduler is None:
+            scheduler = TorchScheduler(
+                torch.optim.lr_scheduler.ConstantLR)
+
         super().__init__(
-            models=[model],
+            model=model,
             problem=problem,
-            optimizers=[optimizer],
-            optimizers_kwargs=[optimizer_kwargs],
-            extra_features=extra_features,
+            optimizer=optimizer,
+            scheduler=scheduler,
         )

         # check consistency
-        check_consistency(scheduler, LRScheduler, subclass=True)
-        check_consistency(scheduler_kwargs, dict)
         check_consistency(loss, (LossInterface, _Loss), subclass=False)

         # assign variables
-        self._scheduler = scheduler(self.optimizers[0], **scheduler_kwargs)
         self._loss = loss
-        self._neural_net = self.models[0]

     def forward(self, x):
         """Forward pass implementation for the solver.
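With the sentinel defaults in place, a bare call reproduces the old behaviour, while custom components are passed pre-wrapped instead of as a class plus a kwargs dict. A minimal usage sketch, assuming a `problem` and `model` are already defined (AdamW is just an illustrative choice):

```python
# defaults filled in by the `if ... is None` branches above:
# MSELoss, Adam(lr=0.001), ConstantLR
solver = SupervisedSolver(problem=problem, model=model)

# custom optimizer: a single wrapper replaces the old
# optimizer= / optimizer_kwargs= pair
solver = SupervisedSolver(
    problem=problem,
    model=model,
    optimizer=TorchOptimizer(torch.optim.AdamW, lr=1e-3),
)
```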
@@ -98,7 +90,7 @@ class SupervisedSolver(SolverInterface):
         :return: Solver solution.
         :rtype: torch.Tensor
         """
-        return self.neural_net(x)
+        return self._pina_model(x)

     def configure_optimizers(self):
         """Optimizer configuration for the solver.
@@ -106,7 +98,9 @@ class SupervisedSolver(SolverInterface):
         :return: The optimizers and the schedulers
         :rtype: tuple(list, list)
         """
-        return self.optimizers, [self.scheduler]
+        self._pina_optimizer.hook(self._pina_model.parameters())
+        self._pina_scheduler.hook(self._pina_optimizer)
+        return self._pina_optimizer, self._pina_scheduler

     def training_step(self, batch, batch_idx):
         """Solver training step.
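The `.hook()` calls explain why the wrappers can be created in `__init__` without a model attached: instantiation of the underlying torch objects is deferred until Lightning invokes `configure_optimizers`, by which point the model's parameters exist. The real implementation lives in `pina.optim` and is not shown in this diff; the class below is only a sketch of the deferred-binding pattern, with every name assumed:

```python
import torch

class LazyOptimizer:
    """Sketch of the deferred-binding idea; not the pina.optim code."""

    def __init__(self, optimizer_class, **kwargs):
        self.optimizer_class = optimizer_class
        self.kwargs = kwargs
        self.instance = None

    def hook(self, parameters):
        # bind to the model only when its parameters are available
        self.instance = self.optimizer_class(parameters, **self.kwargs)
```

A scheduler counterpart would hook the already-bound optimizer the same way, which is why the scheduler is hooked after the optimizer above.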
@@ -168,14 +162,21 @@ class SupervisedSolver(SolverInterface):
         """
         Scheduler for training.
         """
-        return self._scheduler
+        return self._pina_scheduler

+    @property
+    def optimizer(self):
+        """
+        Optimizer for training.
+        """
+        return self._pina_optimizer
+
     @property
-    def neural_net(self):
+    def model(self):
         """
         Neural network for training.
         """
-        return self._neural_net
+        return self._pina_model

     @property
     def loss(self):
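For downstream code, the public surface changes from `neural_net` to `model`, with `optimizer` newly exposed alongside the existing `scheduler`. A short usage sketch, assuming `solver` is an instance built as above:

```python
net = solver.model        # previously: solver.neural_net
opt = solver.optimizer    # new property, backed by _pina_optimizer
sched = solver.scheduler  # unchanged name, now backed by _pina_scheduler
```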