fix doc solver
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
"""Module for GAROM"""
|
||||
"""Module for the GAROM solver."""
|
||||
|
||||
import torch
|
||||
from torch.nn.modules.loss import _Loss
|
||||
@@ -10,9 +10,9 @@ from ..loss import LossInterface, PowerLoss
|
||||
|
||||
class GAROM(MultiSolverInterface):
|
||||
"""
|
||||
GAROM solver class. This class implements Generative Adversarial
|
||||
Reduced Order Model solver, using user specified ``models`` to solve
|
||||
a specific order reduction``problem``.
|
||||
GAROM solver class. This class implements Generative Adversarial Reduced
|
||||
Order Model solver, using user specified ``models`` to solve a specific
|
||||
order reduction ``problem``.
|
||||
|
||||
.. seealso::
|
||||
|
||||
@@ -39,40 +39,28 @@ class GAROM(MultiSolverInterface):
|
||||
regularizer=False,
|
||||
):
|
||||
"""
|
||||
:param AbstractProblem problem: The formualation of the problem.
|
||||
:param torch.nn.Module generator: The neural network model to use
|
||||
for the generator.
|
||||
:param torch.nn.Module discriminator: The neural network model to use
|
||||
Initialization of the :class:`GAROM` class.
|
||||
|
||||
:param AbstractProblem problem: The formulation of the problem.
|
||||
:param torch.nn.Module generator: The generator model.
|
||||
:param torch.nn.Module discriminator: The discriminator model.
|
||||
:param torch.nn.Module loss: The loss function to be minimized.
|
||||
If ``None``, ``PowerLoss(p=1)`` is used. Default is ``None``.
|
||||
:param Optimizer optimizer_generator: The optimizer for the generator.
|
||||
If `None`, the Adam optimizer is used. Default is ``None``.
|
||||
:param Optimizer optimizer_discriminator: The optimizer for the
|
||||
discriminator. If `None`, the Adam optimizer is used.
|
||||
Default is ``None``.
|
||||
:param Scheduler scheduler_generator: The learning rate scheduler for
|
||||
the generator.
|
||||
:param Scheduler scheduler_discriminator: The learning rate scheduler
|
||||
for the discriminator.
|
||||
:param torch.nn.Module loss: The loss function used as minimizer,
|
||||
default ``None``. If ``loss`` is ``None`` the defualt
|
||||
``PowerLoss(p=1)`` is used, as in the original paper.
|
||||
:param Optimizer optimizer_generator: The neural
|
||||
network optimizer to use for the generator network
|
||||
, default is `torch.optim.Adam`.
|
||||
:param Optimizer optimizer_discriminator: The neural
|
||||
network optimizer to use for the discriminator network
|
||||
, default is `torch.optim.Adam`.
|
||||
:param Scheduler scheduler_generator: Learning
|
||||
rate scheduler for the generator.
|
||||
:param Scheduler scheduler_discriminator: Learning
|
||||
rate scheduler for the discriminator.
|
||||
:param dict scheduler_discriminator_kwargs: LR scheduler constructor
|
||||
keyword args.
|
||||
:param gamma: Ratio of expected loss for generator and discriminator,
|
||||
defaults to 0.3.
|
||||
:type gamma: float
|
||||
:param lambda_k: Learning rate for control theory optimization,
|
||||
defaults to 0.001.
|
||||
:type lambda_k: float
|
||||
:param regularizer: Regularization term in the GAROM loss,
|
||||
defaults to False.
|
||||
:type regularizer: bool
|
||||
|
||||
.. warning::
|
||||
The algorithm works only for data-driven model. Hence in the
|
||||
``problem`` definition the codition must only contain ``input``
|
||||
(e.g. coefficient parameters, time parameters), and ``target``.
|
||||
:param float gamma: Ratio of expected loss for generator and
|
||||
discriminator. Default is ``0.3``.
|
||||
:param float lambda_k: Learning rate for control theory optimization.
|
||||
Default is ``0.001``.
|
||||
:param bool regularizer: If ``True``, uses a regularization term in the
|
||||
GAROM loss. Default is ``False``.
|
||||
"""
|
||||
|
||||
# set loss
|
||||
@@ -112,19 +100,15 @@ class GAROM(MultiSolverInterface):
|
||||
|
||||
def forward(self, x, mc_steps=20, variance=False):
|
||||
"""
|
||||
Forward step for GAROM solver
|
||||
Forward pass implementation.
|
||||
|
||||
:param x: The input tensor.
|
||||
:type x: torch.Tensor
|
||||
:param mc_steps: Number of montecarlo samples to approximate the
|
||||
expected value, defaults to 20.
|
||||
:type mc_steps: int
|
||||
:param variance: Returining also the sample variance of the solution,
|
||||
defaults to False.
|
||||
:type variance: bool
|
||||
:param torch.Tensor x: The input tensor.
|
||||
:param int mc_steps: Number of Montecarlo samples to approximate the
|
||||
expected value. Default is ``20``.
|
||||
:param bool variance: If ``True``, the method returns also the variance
|
||||
of the solution. Default is ``False``.
|
||||
:return: The expected value of the generator distribution. If
|
||||
``variance=True`` also the
|
||||
sample variance is returned.
|
||||
``variance=True``, the method returns also the variance.
|
||||
:rtype: torch.Tensor | tuple(torch.Tensor, torch.Tensor)
|
||||
"""
|
||||
|
||||
@@ -142,13 +126,24 @@ class GAROM(MultiSolverInterface):
|
||||
return mean
|
||||
|
||||
def sample(self, x):
|
||||
"""TODO"""
|
||||
"""
|
||||
Sample from the generator distribution.
|
||||
|
||||
:param torch.Tensor x: The input tensor.
|
||||
:return: The generated sample.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
# sampling
|
||||
return self.generator(x)
|
||||
|
||||
def _train_generator(self, parameters, snapshots):
|
||||
"""
|
||||
Private method to train the generator network.
|
||||
Train the generator model.
|
||||
|
||||
:param torch.Tensor parameters: The input tensor.
|
||||
:param torch.Tensor snapshots: The target tensor.
|
||||
:return: The residual loss and the generator loss.
|
||||
:rtype: tuple(torch.Tensor, torch.Tensor)
|
||||
"""
|
||||
optimizer = self.optimizer_generator
|
||||
optimizer.zero_grad()
|
||||
@@ -170,16 +165,13 @@ class GAROM(MultiSolverInterface):
|
||||
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx):
|
||||
"""
|
||||
This method is called at the end of each training batch, and ovverides
|
||||
the PytorchLightining implementation for logging the checkpoints.
|
||||
This method is called at the end of each training batch and overrides
|
||||
the PyTorch Lightning implementation to log checkpoints.
|
||||
|
||||
:param torch.Tensor outputs: The output from the model for the
|
||||
current batch.
|
||||
:param tuple batch: The current batch of data.
|
||||
:param torch.Tensor outputs: The ``model``'s output for the current
|
||||
batch.
|
||||
:param dict batch: The current batch of data.
|
||||
:param int batch_idx: The index of the current batch.
|
||||
:return: Whatever is returned by the parent
|
||||
method ``on_train_batch_end``.
|
||||
:rtype: Any
|
||||
"""
|
||||
# increase by one the counter of optimization to save loggers
|
||||
(
|
||||
@@ -190,7 +182,12 @@ class GAROM(MultiSolverInterface):
|
||||
|
||||
def _train_discriminator(self, parameters, snapshots):
|
||||
"""
|
||||
Private method to train the discriminator network.
|
||||
Train the discriminator model.
|
||||
|
||||
:param torch.Tensor parameters: The input tensor.
|
||||
:param torch.Tensor snapshots: The target tensor.
|
||||
:return: The residual loss and the generator loss.
|
||||
:rtype: tuple(torch.Tensor, torch.Tensor)
|
||||
"""
|
||||
optimizer = self.optimizer_discriminator
|
||||
optimizer.zero_grad()
|
||||
@@ -215,8 +212,15 @@ class GAROM(MultiSolverInterface):
|
||||
|
||||
def _update_weights(self, d_loss_real, d_loss_fake):
|
||||
"""
|
||||
Private method to Update the weights of the generator and discriminator
|
||||
networks.
|
||||
Update the weights of the generator and discriminator models.
|
||||
|
||||
:param torch.Tensor d_loss_real: The discriminator loss computed on
|
||||
dataset samples.
|
||||
:param torch.Tensor d_loss_fake: The discriminator loss computed on
|
||||
generated samples.
|
||||
:return: The difference between the loss computed on the dataset samples
|
||||
and the loss computed on the generated samples.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
|
||||
diff = torch.mean(self.gamma * d_loss_real - d_loss_fake)
|
||||
@@ -227,11 +231,11 @@ class GAROM(MultiSolverInterface):
|
||||
return diff
|
||||
|
||||
def optimization_cycle(self, batch):
|
||||
"""GAROM solver training step.
|
||||
"""
|
||||
The optimization cycle for the GAROM solver.
|
||||
|
||||
:param batch: The batch element in the dataloader.
|
||||
:type batch: tuple
|
||||
:return: The sum of the loss functions.
|
||||
:param tuple batch: The batch element in the dataloader.
|
||||
:return: The loss of the optimization cycle.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
condition_loss = {}
|
||||
@@ -258,6 +262,13 @@ class GAROM(MultiSolverInterface):
|
||||
return condition_loss
|
||||
|
||||
def validation_step(self, batch):
|
||||
"""
|
||||
The validation step for the PINN solver.
|
||||
|
||||
:param dict batch: The batch of data to use in the validation step.
|
||||
:return: The loss of the validation step.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
condition_loss = {}
|
||||
for condition_name, points in batch:
|
||||
parameters, snapshots = (
|
||||
@@ -273,6 +284,13 @@ class GAROM(MultiSolverInterface):
|
||||
return loss
|
||||
|
||||
def test_step(self, batch):
|
||||
"""
|
||||
The test step for the PINN solver.
|
||||
|
||||
:param dict batch: The batch of data to use in the test step.
|
||||
:return: The loss of the test step.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
condition_loss = {}
|
||||
for condition_name, points in batch:
|
||||
parameters, snapshots = (
|
||||
@@ -289,30 +307,60 @@ class GAROM(MultiSolverInterface):
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
"""TODO"""
|
||||
"""
|
||||
The generator model.
|
||||
|
||||
:return: The generator model.
|
||||
:rtype: torch.nn.Module
|
||||
"""
|
||||
return self.models[0]
|
||||
|
||||
@property
|
||||
def discriminator(self):
|
||||
"""TODO"""
|
||||
"""
|
||||
The discriminator model.
|
||||
|
||||
:return: The discriminator model.
|
||||
:rtype: torch.nn.Module
|
||||
"""
|
||||
return self.models[1]
|
||||
|
||||
@property
|
||||
def optimizer_generator(self):
|
||||
"""TODO"""
|
||||
"""
|
||||
The optimizer for the generator.
|
||||
|
||||
:return: The optimizer for the generator.
|
||||
:rtype: Optimizer
|
||||
"""
|
||||
return self.optimizers[0].instance
|
||||
|
||||
@property
|
||||
def optimizer_discriminator(self):
|
||||
"""TODO"""
|
||||
"""
|
||||
The optimizer for the discriminator.
|
||||
|
||||
:return: The optimizer for the discriminator.
|
||||
:rtype: Optimizer
|
||||
"""
|
||||
return self.optimizers[1].instance
|
||||
|
||||
@property
|
||||
def scheduler_generator(self):
|
||||
"""TODO"""
|
||||
"""
|
||||
The scheduler for the generator.
|
||||
|
||||
:return: The scheduler for the generator.
|
||||
:rtype: Scheduler
|
||||
"""
|
||||
return self.schedulers[0].instance
|
||||
|
||||
@property
|
||||
def scheduler_discriminator(self):
|
||||
"""TODO"""
|
||||
"""
|
||||
The scheduler for the discriminator.
|
||||
|
||||
:return: The scheduler for the discriminator.
|
||||
:rtype: Scheduler
|
||||
"""
|
||||
return self.schedulers[1].instance
|
||||
|
||||
Reference in New Issue
Block a user