Refactoring solvers (#541)

* Refactoring solvers

* Simplify compile logic
* Improve and update documentation
* Create SupervisedSolverInterface
* Specialize SupervisedSolver and ReducedOrderModelSolver
* Create EnsembleSolverInterface + EnsembleSupervisedSolver (hierarchy sketched below)
* Create tests for ensemble solvers

* formatter

* codacy

* Fix issues + speed up tests
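For orientation, the bullets above describe the new solver class layout. A purely illustrative sketch of the hierarchy they imply (class names come from the commit message; the bodies and any methods are placeholders, not the actual PINA API):

```python
# Illustrative outline only: names follow the commit message above,
# bodies are placeholders rather than the library's real implementation.
class SupervisedSolverInterface:
    """Shared logic for solvers trained on input/target pairs."""

class SupervisedSolver(SupervisedSolverInterface):
    """Plain supervised-training specialization."""

class ReducedOrderModelSolver(SupervisedSolverInterface):
    """Reduced-order-model specialization."""

class EnsembleSolverInterface:
    """Shared logic for ensembles of solvers."""

class EnsembleSupervisedSolver(EnsembleSolverInterface):
    """Supervised training applied to an ensemble of models."""
```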
Dario Coscia
2025-04-09 14:51:42 +02:00
parent 485c8dd789
commit 6dd7bd2825
37 changed files with 1514 additions and 510 deletions


@@ -48,18 +48,18 @@ class GAROM(MultiSolverInterface):
If ``None``, :class:`~pina.loss.power_loss.PowerLoss` with ``p=1``
is used. Default is ``None``.
:param Optimizer optimizer_generator: The optimizer for the generator.
- If `None`, the :class:`torch.optim.Adam` optimizer is used.
+ If ``None``, the :class:`torch.optim.Adam` optimizer is used.
Default is ``None``.
:param Optimizer optimizer_discriminator: The optimizer for the
- discriminator. If `None`, the :class:`torch.optim.Adam` optimizer is
- used. Default is ``None``.
+ discriminator. If ``None``, the :class:`torch.optim.Adam`
+ optimizer is used. Default is ``None``.
:param Scheduler scheduler_generator: The learning rate scheduler for
the generator.
- If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
+ If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param Scheduler scheduler_discriminator: The learning rate scheduler
for the discriminator.
- If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
+ If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param float gamma: Ratio of expected loss for generator and
discriminator. Default is ``0.3``.
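Taken together, the documented defaults mean a GAROM solver can be built with every optional argument left as ``None``. A minimal usage sketch (hedged: ``problem``, ``generator_net`` and ``discriminator_net`` are user-defined placeholders, and the import path is an assumption about the current package layout):

```python
# Hedged usage sketch -- placeholder objects, not part of this diff.
from pina.solver import GAROM  # import path assumed; may differ by version

solver = GAROM(
    problem=problem,                  # a user-defined parametric problem
    generator=generator_net,          # torch.nn.Module mapping parameters -> snapshots
    discriminator=discriminator_net,  # torch.nn.Module scoring (snapshot, parameters) pairs
    loss=None,                        # None -> PowerLoss(p=1)
    optimizer_generator=None,         # None -> torch.optim.Adam
    optimizer_discriminator=None,     # None -> torch.optim.Adam
    scheduler_generator=None,         # None -> torch.optim.lr_scheduler.ConstantLR
    scheduler_discriminator=None,     # None -> torch.optim.lr_scheduler.ConstantLR
    gamma=0.3,                        # expected generator/discriminator loss ratio
)
```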
@@ -88,7 +88,7 @@ class GAROM(MultiSolverInterface):
check_consistency(
loss, (LossInterface, _Loss, torch.nn.Module), subclass=False
)
- self._loss = loss
+ self._loss_fn = loss
# set automatic optimization for GANs
self.automatic_optimization = False
@@ -157,10 +157,11 @@ class GAROM(MultiSolverInterface):
generated_snapshots = self.sample(parameters)
# generator loss
- r_loss = self._loss(snapshots, generated_snapshots)
+ r_loss = self._loss_fn(snapshots, generated_snapshots)
d_fake = self.discriminator([generated_snapshots, parameters])
g_loss = (
- self._loss(d_fake, generated_snapshots) + self.regularizer * r_loss
+ self._loss_fn(d_fake, generated_snapshots)
+ + self.regularizer * r_loss
)
# backward step
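To make the two-term objective easier to read outside the diff, here is a hedged, self-contained restatement of what this hunk computes (local stand-in names, not the solver's actual attributes or methods):

```python
# Hedged restatement of the generator objective above; names are local stand-ins.
def generator_objective(loss_fn, discriminator, snapshots, generated, parameters, regularizer):
    r_loss = loss_fn(snapshots, generated)           # reconstruction error on real snapshots
    d_fake = discriminator([generated, parameters])  # discriminator output on generated data
    g_loss = loss_fn(d_fake, generated) + regularizer * r_loss
    return r_loss, g_loss
```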
@@ -170,24 +171,6 @@ class GAROM(MultiSolverInterface):
return r_loss, g_loss
- def on_train_batch_end(self, outputs, batch, batch_idx):
- """
- This method is called at the end of each training batch and overrides
- the PyTorch Lightning implementation to log checkpoints.
- :param torch.Tensor outputs: The ``model``'s output for the current
- batch.
- :param list[tuple[str, dict]] batch: A batch of data. Each element is a
- tuple containing a condition name and a dictionary of points.
- :param int batch_idx: The index of the current batch.
- """
- # increase by one the counter of optimization to save loggers
- (
- self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed
- ) += 1
- return super().on_train_batch_end(outputs, batch, batch_idx)
def _train_discriminator(self, parameters, snapshots):
"""
Train the discriminator model.
@@ -207,8 +190,8 @@ class GAROM(MultiSolverInterface):
d_fake = self.discriminator([generated_snapshots, parameters])
# evaluate loss
- d_loss_real = self._loss(d_real, snapshots)
- d_loss_fake = self._loss(d_fake, generated_snapshots.detach())
+ d_loss_real = self._loss_fn(d_real, snapshots)
+ d_loss_fake = self._loss_fn(d_fake, generated_snapshots.detach())
d_loss = d_loss_real - self.k * d_loss_fake
# backward step
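The discriminator objective in this hunk is the reconstruction error on real snapshots minus the error on detached fakes, scaled by the control parameter ``k`` (a BEGAN-style balance term). A hedged restatement with local stand-in names:

```python
# Hedged restatement of the discriminator objective above; names are local stand-ins.
def discriminator_objective(loss_fn, d_real, d_fake, snapshots, generated, k):
    d_loss_real = loss_fn(d_real, snapshots)           # error on real snapshots
    d_loss_fake = loss_fn(d_fake, generated.detach())  # error on fakes, no gradient to the generator
    return d_loss_real - k * d_loss_fake               # equilibrium-controlled difference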
@@ -288,7 +271,7 @@ class GAROM(MultiSolverInterface):
points["target"],
)
snapshots_gen = self.generator(parameters)
- condition_loss[condition_name] = self._loss(
+ condition_loss[condition_name] = self._loss_fn(
snapshots, snapshots_gen
)
loss = self.weighting.aggregate(condition_loss)
@@ -311,7 +294,7 @@ class GAROM(MultiSolverInterface):
points["target"],
)
snapshots_gen = self.generator(parameters)
- condition_loss[condition_name] = self._loss(
+ condition_loss[condition_name] = self._loss_fn(
snapshots, snapshots_gen
)
loss = self.weighting.aggregate(condition_loss)
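Both the validation and test hunks compute one reconstruction loss per condition and then aggregate them through the solver's weighting scheme. A hedged, self-contained sketch of that loop (placeholder names; the batch is assumed to be a list of ``(condition_name, points)`` pairs with ``"input"``/``"target"`` keys):

```python
# Hedged sketch of the per-condition evaluation loop above; placeholder names only.
def evaluate_conditions(generator, loss_fn, weighting, batch):
    condition_loss = {}
    for condition_name, points in batch:
        parameters, snapshots = points["input"], points["target"]
        snapshots_gen = generator(parameters)            # generate snapshots from parameters
        condition_loss[condition_name] = loss_fn(snapshots, snapshots_gen)
    return weighting.aggregate(condition_loss)           # combine per-condition losses
```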