Refactoring solvers (#541)
* Refactoring solvers * Simplify logic compile * Improve and update doc * Create SupervisedSolverInterface * Specialize SupervisedSolver and ReducedOrderModelSolver * Create EnsembleSolverInterface + EnsembleSupervisedSolver * Create tests ensemble solvers * formatter * codacy * fix issues + speedup test
This commit is contained in:
@@ -48,18 +48,18 @@ class GAROM(MultiSolverInterface):
|
||||
If ``None``, :class:`~pina.loss.power_loss.PowerLoss` with ``p=1``
|
||||
is used. Default is ``None``.
|
||||
:param Optimizer optimizer_generator: The optimizer for the generator.
|
||||
If `None`, the :class:`torch.optim.Adam` optimizer is used.
|
||||
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
|
||||
Default is ``None``.
|
||||
:param Optimizer optimizer_discriminator: The optimizer for the
|
||||
discriminator. If `None`, the :class:`torch.optim.Adam` optimizer is
|
||||
used. Default is ``None``.
|
||||
discriminator. If ``None``, the :class:`torch.optim.Adam`
|
||||
optimizer is used. Default is ``None``.
|
||||
:param Scheduler scheduler_generator: The learning rate scheduler for
|
||||
the generator.
|
||||
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
|
||||
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
|
||||
scheduler is used. Default is ``None``.
|
||||
:param Scheduler scheduler_discriminator: The learning rate scheduler
|
||||
for the discriminator.
|
||||
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
|
||||
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
|
||||
scheduler is used. Default is ``None``.
|
||||
:param float gamma: Ratio of expected loss for generator and
|
||||
discriminator. Default is ``0.3``.
|
||||
@@ -88,7 +88,7 @@ class GAROM(MultiSolverInterface):
|
||||
check_consistency(
|
||||
loss, (LossInterface, _Loss, torch.nn.Module), subclass=False
|
||||
)
|
||||
self._loss = loss
|
||||
self._loss_fn = loss
|
||||
|
||||
# set automatic optimization for GANs
|
||||
self.automatic_optimization = False
|
||||
@@ -157,10 +157,11 @@ class GAROM(MultiSolverInterface):
|
||||
generated_snapshots = self.sample(parameters)
|
||||
|
||||
# generator loss
|
||||
r_loss = self._loss(snapshots, generated_snapshots)
|
||||
r_loss = self._loss_fn(snapshots, generated_snapshots)
|
||||
d_fake = self.discriminator([generated_snapshots, parameters])
|
||||
g_loss = (
|
||||
self._loss(d_fake, generated_snapshots) + self.regularizer * r_loss
|
||||
self._loss_fn(d_fake, generated_snapshots)
|
||||
+ self.regularizer * r_loss
|
||||
)
|
||||
|
||||
# backward step
|
||||
@@ -170,24 +171,6 @@ class GAROM(MultiSolverInterface):
|
||||
|
||||
return r_loss, g_loss
|
||||
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx):
|
||||
"""
|
||||
This method is called at the end of each training batch and overrides
|
||||
the PyTorch Lightning implementation to log checkpoints.
|
||||
|
||||
:param torch.Tensor outputs: The ``model``'s output for the current
|
||||
batch.
|
||||
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
|
||||
tuple containing a condition name and a dictionary of points.
|
||||
:param int batch_idx: The index of the current batch.
|
||||
"""
|
||||
# increase by one the counter of optimization to save loggers
|
||||
(
|
||||
self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed
|
||||
) += 1
|
||||
|
||||
return super().on_train_batch_end(outputs, batch, batch_idx)
|
||||
|
||||
def _train_discriminator(self, parameters, snapshots):
|
||||
"""
|
||||
Train the discriminator model.
|
||||
@@ -207,8 +190,8 @@ class GAROM(MultiSolverInterface):
|
||||
d_fake = self.discriminator([generated_snapshots, parameters])
|
||||
|
||||
# evaluate loss
|
||||
d_loss_real = self._loss(d_real, snapshots)
|
||||
d_loss_fake = self._loss(d_fake, generated_snapshots.detach())
|
||||
d_loss_real = self._loss_fn(d_real, snapshots)
|
||||
d_loss_fake = self._loss_fn(d_fake, generated_snapshots.detach())
|
||||
d_loss = d_loss_real - self.k * d_loss_fake
|
||||
|
||||
# backward step
|
||||
@@ -288,7 +271,7 @@ class GAROM(MultiSolverInterface):
|
||||
points["target"],
|
||||
)
|
||||
snapshots_gen = self.generator(parameters)
|
||||
condition_loss[condition_name] = self._loss(
|
||||
condition_loss[condition_name] = self._loss_fn(
|
||||
snapshots, snapshots_gen
|
||||
)
|
||||
loss = self.weighting.aggregate(condition_loss)
|
||||
@@ -311,7 +294,7 @@ class GAROM(MultiSolverInterface):
|
||||
points["target"],
|
||||
)
|
||||
snapshots_gen = self.generator(parameters)
|
||||
condition_loss[condition_name] = self._loss(
|
||||
condition_loss[condition_name] = self._loss_fn(
|
||||
snapshots, snapshots_gen
|
||||
)
|
||||
loss = self.weighting.aggregate(condition_loss)
|
||||
|
||||
Reference in New Issue
Block a user