Update solvers (#434)

* Enable DDP training with batch_size=None and add a validity check for split sizes
* Refactor SolverInterfaces (#435)
* Update solvers and add loss weighting
* Update PINN for 0.2
* Modify GAROM and its tests
* Add more versatile loggers
* Disable compilation when running on Windows
* Fix tests

---------

Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: FilippoOlivo <filippo@filippoolivo.com>
Author: Dario Coscia
Date: 2025-02-17 11:26:21 +01:00
Committed by: Nicola Demo
Parent: 780c4921eb
Commit: 9cae9a438f
50 changed files with 2848 additions and 4187 deletions

Diff excerpt — the SupervisedSolver module (one of the 50 changed files):

@@ -1,15 +1,13 @@
 """ Module for SupervisedSolver """
 
 import torch
 from torch.nn.modules.loss import _Loss
 from ..optim import TorchOptimizer, TorchScheduler
-from .solver import SolverInterface
-from ..label_tensor import LabelTensor
+from .solver import SingleSolverInterface
 from ..utils import check_consistency
 from ..loss.loss_interface import LossInterface
+from ..condition import InputOutputPointsCondition
 
 
-class SupervisedSolver(SolverInterface):
+class SupervisedSolver(SingleSolverInterface):
     r"""
     SupervisedSolver solver class. This class implements a SupervisedSolver,
     using a user specified ``model`` to solve a specific ``problem``.
@@ -46,110 +44,54 @@ class SupervisedSolver(SolverInterface):
                  loss=None,
                  optimizer=None,
                  scheduler=None,
-                 extra_features=None,
+                 weighting=None,
                  use_lt=True):
         """
         :param AbstractProblem problem: The formulation of the problem.
         :param torch.nn.Module model: The neural network model to use.
         :param torch.nn.Module loss: The loss function used as minimizer,
             default :class:`torch.nn.MSELoss`.
-        :param torch.nn.Module extra_features: The additional input
-            features to use as augmented input.
         :param torch.optim.Optimizer optimizer: The neural network optimizer to
             use; default is :class:`torch.optim.Adam`.
         :param torch.optim.LRScheduler scheduler: Learning
             rate scheduler.
+        :param WeightingInterface weighting: The loss weighting to use.
+        :param bool use_lt: Using LabelTensors as input during training.
         """
         if loss is None:
             loss = torch.nn.MSELoss()
         if optimizer is None:
             optimizer = TorchOptimizer(torch.optim.Adam, lr=0.001)
         if scheduler is None:
             scheduler = TorchScheduler(torch.optim.lr_scheduler.ConstantLR)
-        super().__init__(models=model,
+        super().__init__(model=model,
                          problem=problem,
-                         optimizers=optimizer,
-                         schedulers=scheduler,
-                         extra_features=extra_features,
+                         optimizer=optimizer,
+                         scheduler=scheduler,
+                         weighting=weighting,
                          use_lt=use_lt)
 
         # check consistency
         check_consistency(loss, (LossInterface, _Loss, torch.nn.Module),
                           subclass=False)
         self._loss = loss
-        self._model = self._pina_models[0]
-        self._optimizer = self._pina_optimizers[0]
-        self._scheduler = self._pina_schedulers[0]
-        self.validation_condition_losses = {
-            k: {'loss': [],
-                'count': []} for k in self.problem.conditions.keys()}
 
-    def forward(self, x):
-        """Forward pass implementation for the solver.
-
-        :param torch.Tensor x: Input tensor.
-        :return: Solver solution.
-        :rtype: torch.Tensor
-        """
-        output = self._model(x)
-        output.labels = self.problem.output_variables
-        return output
-
-    def configure_optimizers(self):
-        """Optimizer configuration for the solver.
-
-        :return: The optimizers and the schedulers
-        :rtype: tuple(list, list)
-        """
-        self._optimizer.hook(self._model.parameters())
-        self._scheduler.hook(self._optimizer)
-        return ([self._optimizer.optimizer_instance],
-                [self._scheduler.scheduler_instance])
-
-    def training_step(self, batch):
-        """Solver training step.
-
-        :param batch: The batch element in the dataloader.
-        :type batch: tuple
-        :param batch_idx: The batch index.
-        :type batch_idx: int
-        :return: The sum of the loss functions.
-        :rtype: LabelTensor
-        """
-        condition_loss = []
-        for condition_name, points in batch:
-            input_pts, output_pts = points['input_points'], points['output_points']
-            loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
-            condition_loss.append(loss_.as_subclass(torch.Tensor))
-        loss = sum(condition_loss)
-        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True,
-                 logger=True, batch_size=self.get_batch_size(batch), sync_dist=True)
-        return loss
-
-    def validation_step(self, batch):
-        """
-        Solver validation step.
-        """
-        condition_loss = []
-        for condition_name, points in batch:
-            input_pts, output_pts = points['input_points'], points['output_points']
-            loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
-            condition_loss.append(loss_.as_subclass(torch.Tensor))
-        loss = sum(condition_loss)
-        self.log('val_loss', loss, prog_bar=True, logger=True,
-                 batch_size=self.get_batch_size(batch), sync_dist=True)
-
-    def test_step(self, batch, batch_idx):
-        """
-        Solver test step.
-        """
-        raise NotImplementedError("Test step not implemented yet.")
+    def optimization_cycle(self, batch):
+        """
+        Perform an optimization cycle by computing the loss for each condition
+        in the given batch.
+
+        :param batch: A batch of data, where each element is a tuple containing
+            a condition name and a dictionary of points.
+        :type batch: list of tuples (str, dict)
+        :return: The computed loss for all conditions in the batch,
+            cast to a subclass of `torch.Tensor`. It should return a dict
+            containing the condition name and the associated scalar loss.
+        :rtype: dict(torch.Tensor)
+        """
+        condition_loss = {}
+        for condition_name, points in batch:
+            input_pts, output_pts = points['input_points'], points['output_points']
+            condition_loss[condition_name] = self.loss_data(
+                input_pts=input_pts, output_pts=output_pts)
+        return condition_loss
 
     def loss_data(self, input_pts, output_pts):
         """
@@ -157,38 +99,19 @@ class SupervisedSolver(SolverInterface):
         the network output against the true solution. This function
         should not be overridden unless intentional.
 
-        :param LabelTensor input_pts: The input to the neural networks.
-        :param LabelTensor output_pts: The true solution to compare with the
+        :param input_pts: The input to the neural networks.
+        :type input_pts: LabelTensor | torch.Tensor
+        :param output_pts: The true solution to compare with the
             network solution.
-        :return: The residual loss averaged on the input coordinates
+        :type output_pts: LabelTensor | torch.Tensor
+        :return: The residual loss.
         :rtype: torch.Tensor
         """
         return self._loss(self.forward(input_pts), output_pts)
 
-    @property
-    def scheduler(self):
-        """
-        Scheduler for training.
-        """
-        return self._scheduler
-
-    @property
-    def optimizer(self):
-        """
-        Optimizer for training.
-        """
-        return self._optimizer
-
-    @property
-    def model(self):
-        """
-        Neural network for training.
-        """
-        return self._model
-
     @property
     def loss(self):
         """
         Loss for training.
         """
         return self._loss
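In this hunk the scheduler, optimizer, and model properties move up into SingleSolverInterface, so the class keeps only loss. Since the docstring warns that loss_data should only be overridden intentionally, here is a speculative sketch of such an intentional override; the import path and the relative-error variant are assumptions for illustration, not part of this commit.

import torch
from pina.solvers import SupervisedSolver  # import path assumed, may differ

class RelativeMSESolver(SupervisedSolver):
    # Hypothetical subclass: intentionally overrides loss_data to normalize
    # the residual by the target magnitude (relative rather than absolute error).
    def loss_data(self, input_pts, output_pts):
        residual = self._loss(self.forward(input_pts), output_pts)
        return residual / (output_pts.abs().mean() + 1e-8)

Because optimization_cycle calls loss_data once per condition, an override like this changes the per-condition losses fed to the weighting scheme without touching the training loop itself.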