Update solvers (#434)
* Enable DDP training with batch_size=None and add validity check for split sizes
* Refactor SolverInterfaces (#435)
* Solver update + weighting
* Update PINN for 0.2
* Modify GAROM + tests
* Add more versatile loggers
* Disable compilation when running on Windows
* Fix tests

Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: FilippoOlivo <filippo@filippoolivo.com>
Commit: 9cae9a438f
Parent: 780c4921eb
Committed by: Nicola Demo
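The diff below covers the SupervisedSolver part of this refactor: the class now derives from SingleSolverInterface instead of SolverInterface, the constructor takes singular model/optimizer/scheduler arguments plus an optional loss weighting, and the per-condition loss computation moves into a new optimization_cycle method. For orientation, here is a minimal sketch of the defaults the new constructor fills in when loss, optimizer, and scheduler are left as None; plain torch objects stand in for PINA's TorchOptimizer and TorchScheduler wrappers, so this is illustrative rather than the actual solver internals:

```python
import torch

# Defaults applied by SupervisedSolver.__init__ after this change (see the
# diff below). The real code wraps the optimizer and scheduler classes in
# PINA's TorchOptimizer/TorchScheduler instead of instantiating them directly.
model = torch.nn.Linear(1, 1)  # stand-in for any user-supplied torch.nn.Module
loss = torch.nn.MSELoss()      # default when loss is None
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # default optimizer
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)  # default scheduler
```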
@@ -1,15 +1,13 @@
 """ Module for SupervisedSolver """
 import torch
 from torch.nn.modules.loss import _Loss
 from ..optim import TorchOptimizer, TorchScheduler
-from .solver import SolverInterface
-from ..label_tensor import LabelTensor
+from .solver import SingleSolverInterface
 from ..utils import check_consistency
 from ..loss.loss_interface import LossInterface
+from ..condition import InputOutputPointsCondition
 
 
-class SupervisedSolver(SolverInterface):
+class SupervisedSolver(SingleSolverInterface):
     r"""
     SupervisedSolver solver class. This class implements a SupervisedSolver,
     using a user specified ``model`` to solve a specific ``problem``.
@@ -46,110 +44,54 @@ class SupervisedSolver(SolverInterface):
                  loss=None,
                  optimizer=None,
                  scheduler=None,
-                 extra_features=None,
+                 weighting=None,
                  use_lt=True):
         """
         :param AbstractProblem problem: The formulation of the problem.
         :param torch.nn.Module model: The neural network model to use.
         :param torch.nn.Module loss: The loss function used as minimizer,
             default :class:`torch.nn.MSELoss`.
-        :param torch.nn.Module extra_features: The additional input
-            features to use as augmented input.
         :param torch.optim.Optimizer optimizer: The neural network optimizer
             to use; default is :class:`torch.optim.Adam`.
         :param torch.optim.LRScheduler scheduler: Learning rate scheduler.
+        :param WeightingInterface weighting: The loss weighting to use.
         :param bool use_lt: Whether to use LabelTensors as input during
             training.
         """
         if loss is None:
             loss = torch.nn.MSELoss()
 
         if optimizer is None:
             optimizer = TorchOptimizer(torch.optim.Adam, lr=0.001)
 
         if scheduler is None:
             scheduler = TorchScheduler(torch.optim.lr_scheduler.ConstantLR)
 
-        super().__init__(models=model,
+        super().__init__(model=model,
                          problem=problem,
-                         optimizers=optimizer,
-                         schedulers=scheduler,
-                         extra_features=extra_features,
+                         optimizer=optimizer,
+                         scheduler=scheduler,
+                         weighting=weighting,
                          use_lt=use_lt)
 
         # check consistency
         check_consistency(loss, (LossInterface, _Loss, torch.nn.Module),
                           subclass=False)
         self._loss = loss
-        self._model = self._pina_models[0]
-        self._optimizer = self._pina_optimizers[0]
-        self._scheduler = self._pina_schedulers[0]
-        self.validation_condition_losses = {
-            k: {'loss': [],
-                'count': []} for k in self.problem.conditions.keys()}
 
-    def forward(self, x):
-        """Forward pass implementation for the solver.
-
-        :param torch.Tensor x: Input tensor.
-        :return: Solver solution.
-        :rtype: torch.Tensor
-        """
-        output = self._model(x)
-        output.labels = self.problem.output_variables
-        return output
-
-    def configure_optimizers(self):
-        """Optimizer configuration for the solver.
-
-        :return: The optimizers and the schedulers.
-        :rtype: tuple(list, list)
-        """
-        self._optimizer.hook(self._model.parameters())
-        self._scheduler.hook(self._optimizer)
-        return ([self._optimizer.optimizer_instance],
-                [self._scheduler.scheduler_instance])
-
-    def training_step(self, batch):
-        """Solver training step.
-
-        :param batch: The batch element in the dataloader.
-        :type batch: tuple
-        :return: The sum of the loss functions.
-        :rtype: LabelTensor
-        """
-        condition_loss = []
-        for condition_name, points in batch:
-            input_pts, output_pts = points['input_points'], points['output_points']
-            loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
-            condition_loss.append(loss_.as_subclass(torch.Tensor))
-        loss = sum(condition_loss)
-        self.log('train_loss', loss, on_step=True, on_epoch=True,
-                 prog_bar=True, logger=True,
-                 batch_size=self.get_batch_size(batch), sync_dist=True)
-        return loss
-
-    def validation_step(self, batch):
-        """
-        Solver validation step.
-        """
-        condition_loss = []
-        for condition_name, points in batch:
-            input_pts, output_pts = points['input_points'], points['output_points']
-            loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
-            condition_loss.append(loss_.as_subclass(torch.Tensor))
-        loss = sum(condition_loss)
-        self.log('val_loss', loss, prog_bar=True, logger=True,
-                 batch_size=self.get_batch_size(batch), sync_dist=True)
-
-    def test_step(self, batch, batch_idx):
-        """
-        Solver test step.
-        """
-        raise NotImplementedError("Test step not implemented yet.")
+    def optimization_cycle(self, batch):
+        """
+        Perform an optimization cycle by computing the loss for each condition
+        in the given batch.
+
+        :param batch: A batch of data, where each element is a tuple containing
+            a condition name and a dictionary of points.
+        :type batch: list of tuples (str, dict)
+        :return: The computed loss for all conditions in the batch, cast to
+            a subclass of :class:`torch.Tensor`. It should return a dict
+            containing the condition name and the associated scalar loss.
+        :rtype: dict(torch.Tensor)
+        """
+        condition_loss = {}
+        for condition_name, points in batch:
+            input_pts, output_pts = points['input_points'], points['output_points']
+            condition_loss[condition_name] = self.loss_data(
+                input_pts=input_pts, output_pts=output_pts)
+        return condition_loss
 
     def loss_data(self, input_pts, output_pts):
         """
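The training_step/validation_step/test_step trio above collapses into a single optimization_cycle that returns one scalar loss per condition; summing, weighting, and logging are presumably handled upstream by SingleSolverInterface and the new weighting component rather than in this class. A self-contained sketch of the batch layout it consumes, with the loop mirrored standalone (the condition names, tensor shapes, and stand-in model are hypothetical, not from this commit):

```python
import torch

# Hypothetical batch in the layout optimization_cycle expects: a list of
# (condition_name, points) pairs, where points maps 'input_points' and
# 'output_points' to tensors.
batch = [
    ("data_a", {"input_points": torch.rand(8, 1), "output_points": torch.rand(8, 1)}),
    ("data_b", {"input_points": torch.rand(4, 1), "output_points": torch.rand(4, 1)}),
]

model = torch.nn.Linear(1, 1)  # stand-in for the solver's model
loss_fn = torch.nn.MSELoss()   # the solver's default loss

# Mirrors the loop in the new optimization_cycle: one scalar loss per
# condition, returned as a dict keyed by condition name.
condition_loss = {}
for condition_name, points in batch:
    prediction = model(points["input_points"])
    condition_loss[condition_name] = loss_fn(prediction, points["output_points"])

print({name: value.item() for name, value in condition_loss.items()})
```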
@@ -157,38 +99,19 @@ class SupervisedSolver(SolverInterface):
         the network output against the true solution. This function
         should not be overridden unless intentionally.
 
-        :param LabelTensor input_pts: The input to the neural networks.
-        :param LabelTensor output_pts: The true solution to compare the
+        :param input_pts: The input to the neural networks.
+        :type input_pts: LabelTensor | torch.Tensor
+        :param output_pts: The true solution to compare the
             network solution.
-        :return: The residual loss averaged on the input coordinates.
+        :type output_pts: LabelTensor | torch.Tensor
+        :return: The residual loss.
         :rtype: torch.Tensor
         """
         return self._loss(self.forward(input_pts), output_pts)
 
     @property
     def scheduler(self):
         """
         Scheduler for training.
         """
         return self._scheduler
 
     @property
     def optimizer(self):
         """
         Optimizer for training.
         """
         return self._optimizer
 
     @property
     def model(self):
         """
         Neural network for training.
         """
         return self._model
 
     @property
     def loss(self):
         """
         Loss for training.
         """
         return self._loss
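Since check_consistency accepts any LossInterface, torch _Loss, or plain torch.nn.Module as `loss`, and loss_data simply evaluates `self._loss(self.forward(input_pts), output_pts)`, a custom criterion can be dropped in without subclassing the solver. A minimal sketch; the RelativeMSELoss class below is a hypothetical example, not part of PINA or this commit:

```python
import torch

class RelativeMSELoss(torch.nn.Module):
    """Hypothetical criterion: MSE normalized by the target's mean square."""

    def forward(self, prediction, target):
        mse = torch.mean((prediction - target) ** 2)
        return mse / (torch.mean(target ** 2) + 1e-12)  # eps avoids division by zero

# Any module with this (prediction, target) signature can be passed as the
# solver's `loss` argument, since loss_data feeds it the network output and
# the true solution in exactly this order.
criterion = RelativeMSELoss()
print(criterion(torch.rand(16, 1), torch.rand(16, 1)))
```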