Refactoring solvers (#541)
* Refactoring solvers * Simplify logic compile * Improve and update doc * Create SupervisedSolverInterface * Specialize SupervisedSolver and ReducedOrderModelSolver * Create EnsembleSolverInterface + EnsembleSupervisedSolver * Create tests ensemble solvers * formatter * codacy * fix issues + speedup test
This commit is contained in:
committed by
FilippoOlivo
parent
fa6fda0bd5
commit
1bb3c125ac
@@ -69,26 +69,26 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
:param AbstractProblem problem: The problem to be solved.
|
||||
:param torch.nn.Module model: The neural network model to be used.
|
||||
:param torch.nn.Module discriminator: The discriminator to be used.
|
||||
If `None`, the discriminator is a deepcopy of the ``model``.
|
||||
If ``None``, the discriminator is a deepcopy of the ``model``.
|
||||
Default is ``None``.
|
||||
:param torch.optim.Optimizer optimizer_model: The optimizer of the
|
||||
``model``. If `None`, the :class:`torch.optim.Adam` optimizer is
|
||||
``model``. If ``None``, the :class:`torch.optim.Adam` optimizer is
|
||||
used. Default is ``None``.
|
||||
:param torch.optim.Optimizer optimizer_discriminator: The optimizer of
|
||||
the ``discriminator``. If `None`, the :class:`torch.optim.Adam`
|
||||
the ``discriminator``. If ``None``, the :class:`torch.optim.Adam`
|
||||
optimizer is used. Default is ``None``.
|
||||
:param Scheduler scheduler_model: Learning rate scheduler for the
|
||||
``model``.
|
||||
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
|
||||
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
|
||||
scheduler is used. Default is ``None``.
|
||||
:param Scheduler scheduler_discriminator: Learning rate scheduler for
|
||||
the ``discriminator``.
|
||||
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
|
||||
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
|
||||
scheduler is used. Default is ``None``.
|
||||
:param WeightingInterface weighting: The weighting schema to be used.
|
||||
If `None`, no weighting schema is used. Default is ``None``.
|
||||
If ``None``, no weighting schema is used. Default is ``None``.
|
||||
:param torch.nn.Module loss: The loss function to be minimized.
|
||||
If `None`, the :class:`torch.nn.MSELoss` loss is used.
|
||||
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
|
||||
Default is `None`.
|
||||
"""
|
||||
if discriminator is None:
|
||||
@@ -156,12 +156,27 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
residual = residual * discriminator_bets
|
||||
|
||||
# Compute competitive residual.
|
||||
loss_val = self.loss(
|
||||
loss_val = self._loss_fn(
|
||||
torch.zeros_like(residual, requires_grad=True),
|
||||
residual,
|
||||
)
|
||||
return loss_val
|
||||
|
||||
def loss_data(self, input, target):
|
||||
"""
|
||||
Compute the data loss for the PINN solver by evaluating the loss
|
||||
between the network's output and the true solution. This method should
|
||||
not be overridden, if not intentionally.
|
||||
|
||||
:param input: The input to the neural network.
|
||||
:type input: LabelTensor
|
||||
:param target: The target to compare with the network's output.
|
||||
:type target: LabelTensor
|
||||
:return: The supervised loss, averaged over the number of observations.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
return self._loss_fn(self.forward(input), target)
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""
|
||||
Optimizer configuration.
|
||||
@@ -195,24 +210,6 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
],
|
||||
)
|
||||
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx):
|
||||
"""
|
||||
This method is called at the end of each training batch and overrides
|
||||
the PyTorch Lightning implementation to log checkpoints.
|
||||
|
||||
:param torch.Tensor outputs: The ``model``'s output for the current
|
||||
batch.
|
||||
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
|
||||
tuple containing a condition name and a dictionary of points.
|
||||
:param int batch_idx: The index of the current batch.
|
||||
"""
|
||||
# increase by one the counter of optimization to save loggers
|
||||
(
|
||||
self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed
|
||||
) += 1
|
||||
|
||||
return super().on_train_batch_end(outputs, batch, batch_idx)
|
||||
|
||||
@property
|
||||
def neural_net(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user