This commit is contained in:
Nicola Demo
2024-08-05 17:34:34 +02:00
parent 686b557144
commit 5245a0b68c
19 changed files with 483 additions and 173 deletions

View File

@@ -1,21 +1,14 @@
""" Module for SupervisedSolver """
import torch
from torch.nn.modules.loss import _Loss
try:
from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0
except ImportError:
from torch.optim.lr_scheduler import (
_LRScheduler as LRScheduler,
) # torch < 2.0
from torch.optim.lr_scheduler import ConstantLR
from ..optim import Optimizer, Scheduler, TorchOptimizer, TorchScheduler
from .solver import SolverInterface
from ..label_tensor import LabelTensor
from ..utils import check_consistency
from ..loss import LossInterface
from torch.nn.modules.loss import _Loss
class SupervisedSolver(SolverInterface):
@@ -51,12 +44,9 @@ class SupervisedSolver(SolverInterface):
self,
problem,
model,
extra_features=None,
loss=torch.nn.MSELoss(),
optimizer=torch.optim.Adam,
optimizer_kwargs={"lr": 0.001},
scheduler=ConstantLR,
scheduler_kwargs={"factor": 1, "total_iters": 0},
loss=None,
optimizer=None,
scheduler=None,
):
"""
:param AbstractProblem problem: The formualation of the problem.
@@ -73,24 +63,26 @@ class SupervisedSolver(SolverInterface):
rate scheduler.
:param dict scheduler_kwargs: LR scheduler constructor keyword args.
"""
if loss is None:
loss = torch.nn.MSELoss()
if optimizer is None:
optimizer = TorchOptimizer(torch.optim.Adam, lr=0.001)
if scheduler is None:
scheduler = TorchScheduler(
torch.optim.lr_scheduler.ConstantLR)
super().__init__(
models=[model],
model=model,
problem=problem,
optimizers=[optimizer],
optimizers_kwargs=[optimizer_kwargs],
extra_features=extra_features,
optimizer=optimizer,
scheduler=scheduler,
)
# check consistency
check_consistency(scheduler, LRScheduler, subclass=True)
check_consistency(scheduler_kwargs, dict)
check_consistency(loss, (LossInterface, _Loss), subclass=False)
# assign variables
self._scheduler = scheduler(self.optimizers[0], **scheduler_kwargs)
self._loss = loss
self._neural_net = self.models[0]
def forward(self, x):
"""Forward pass implementation for the solver.
@@ -98,7 +90,7 @@ class SupervisedSolver(SolverInterface):
:return: Solver solution.
:rtype: torch.Tensor
"""
return self.neural_net(x)
return self._pina_model(x)
def configure_optimizers(self):
"""Optimizer configuration for the solver.
@@ -106,7 +98,9 @@ class SupervisedSolver(SolverInterface):
:return: The optimizers and the schedulers
:rtype: tuple(list, list)
"""
return self.optimizers, [self.scheduler]
self._pina_optimizer.hook(self._pina_model.parameters())
self._pina_scheduler.hook(self._pina_optimizer)
return self._pina_optimizer, self._pina_scheduler
def training_step(self, batch, batch_idx):
"""Solver training step.
@@ -168,14 +162,21 @@ class SupervisedSolver(SolverInterface):
"""
Scheduler for training.
"""
return self._scheduler
return self._pina_scheduler
@property
def optimizer(self):
"""
Optimizer for training.
"""
return self._pina_optimizer
@property
def neural_net(self):
def model(self):
"""
Neural network for training.
"""
return self._neural_net
return self._pina_model
@property
def loss(self):