fix doc solver

Author: giovanni
Date: 2025-03-13 19:05:15 +01:00
Committed by: FilippoOlivo
Parent: fd2b50fc06
Commit: 5e6aa61592
6 changed files with 374 additions and 251 deletions


@@ -1,4 +1,4 @@
"""Module for GAROM"""
"""Module for the GAROM solver."""
import torch
from torch.nn.modules.loss import _Loss
@@ -10,9 +10,9 @@ from ..loss import LossInterface, PowerLoss
 class GAROM(MultiSolverInterface):
     """
-    GAROM solver class. This class implements Generative Adversarial
-    Reduced Order Model solver, using user specified ``models`` to solve
-    a specific order reduction``problem``.
+    GAROM solver class. This class implements Generative Adversarial Reduced
+    Order Model solver, using user specified ``models`` to solve a specific
+    order reduction ``problem``.

     .. seealso::
@@ -39,40 +39,28 @@ class GAROM(MultiSolverInterface):
         regularizer=False,
     ):
         """
-        :param AbstractProblem problem: The formualation of the problem.
-        :param torch.nn.Module generator: The neural network model to use
-            for the generator.
-        :param torch.nn.Module discriminator: The neural network model to use
-            for the discriminator.
-        :param torch.nn.Module loss: The loss function used as minimizer,
-            default ``None``. If ``loss`` is ``None`` the defualt
-            ``PowerLoss(p=1)`` is used, as in the original paper.
-        :param Optimizer optimizer_generator: The neural
-            network optimizer to use for the generator network
-            , default is `torch.optim.Adam`.
-        :param Optimizer optimizer_discriminator: The neural
-            network optimizer to use for the discriminator network
-            , default is `torch.optim.Adam`.
-        :param Scheduler scheduler_generator: Learning
-            rate scheduler for the generator.
-        :param Scheduler scheduler_discriminator: Learning
-            rate scheduler for the discriminator.
-        :param dict scheduler_discriminator_kwargs: LR scheduler constructor
-            keyword args.
-        :param gamma: Ratio of expected loss for generator and discriminator,
-            defaults to 0.3.
-        :type gamma: float
-        :param lambda_k: Learning rate for control theory optimization,
-            defaults to 0.001.
-        :type lambda_k: float
-        :param regularizer: Regularization term in the GAROM loss,
-            defaults to False.
-        :type regularizer: bool
+        Initialization of the :class:`GAROM` class.
+
+        :param AbstractProblem problem: The formulation of the problem.
+        :param torch.nn.Module generator: The generator model.
+        :param torch.nn.Module discriminator: The discriminator model.
+        :param torch.nn.Module loss: The loss function to be minimized.
+            If ``None``, ``PowerLoss(p=1)`` is used. Default is ``None``.
+        :param Optimizer optimizer_generator: The optimizer for the generator.
+            If ``None``, the Adam optimizer is used. Default is ``None``.
+        :param Optimizer optimizer_discriminator: The optimizer for the
+            discriminator. If ``None``, the Adam optimizer is used.
+            Default is ``None``.
+        :param Scheduler scheduler_generator: The learning rate scheduler for
+            the generator.
+        :param Scheduler scheduler_discriminator: The learning rate scheduler
+            for the discriminator.
+        :param float gamma: Ratio of expected loss for generator and
+            discriminator. Default is ``0.3``.
+        :param float lambda_k: Learning rate for control theory optimization.
+            Default is ``0.001``.
+        :param bool regularizer: If ``True``, uses a regularization term in the
+            GAROM loss. Default is ``False``.

         .. warning::
             The algorithm works only for data-driven models. Hence, in the
             ``problem`` definition the conditions must only contain ``input``
             (e.g. coefficient parameters, time parameters), and ``target``.
         """
         # set loss
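
As a point of reference, here is a hedged sketch of how the constructor documented above might be called. The import path, the toy networks, and the variable names are assumptions for illustration, not taken from this diff; `problem` stands for a data-driven AbstractProblem defined elsewhere.

    import torch
    from pina.solver import GAROM  # assumed import path

    # toy parameter-to-snapshot networks; real GAROM models are richer
    generator = torch.nn.Sequential(
        torch.nn.Linear(2, 64), torch.nn.Tanh(), torch.nn.Linear(64, 100)
    )
    discriminator = torch.nn.Sequential(
        torch.nn.Linear(100, 64), torch.nn.Tanh(), torch.nn.Linear(64, 100)
    )

    solver = GAROM(
        problem=problem,    # data-driven AbstractProblem, defined elsewhere
        generator=generator,
        discriminator=discriminator,
        loss=None,          # falls back to PowerLoss(p=1)
        gamma=0.3,          # ratio of expected generator/discriminator loss
        lambda_k=0.001,     # learning rate of the control-theory update
        regularizer=False,  # no extra regularization term in the GAROM loss
    )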
@@ -112,19 +100,15 @@ class GAROM(MultiSolverInterface):
     def forward(self, x, mc_steps=20, variance=False):
         """
-        Forward step for GAROM solver
+        Forward pass implementation.

-        :param x: The input tensor.
-        :type x: torch.Tensor
-        :param mc_steps: Number of montecarlo samples to approximate the
-            expected value, defaults to 20.
-        :type mc_steps: int
-        :param variance: Returining also the sample variance of the solution,
-            defaults to False.
-        :type variance: bool
+        :param torch.Tensor x: The input tensor.
+        :param int mc_steps: Number of Monte Carlo samples to approximate the
+            expected value. Default is ``20``.
+        :param bool variance: If ``True``, the method returns also the variance
+            of the solution. Default is ``False``.
         :return: The expected value of the generator distribution. If
-            ``variance=True`` also the
-            sample variance is returned.
+            ``variance=True``, the method returns also the variance.
         :rtype: torch.Tensor | tuple(torch.Tensor, torch.Tensor)
         """
@@ -142,13 +126,24 @@ class GAROM(MultiSolverInterface):
         return mean

     def sample(self, x):
-        """TODO"""
+        """
+        Sample from the generator distribution.
+
+        :param torch.Tensor x: The input tensor.
+        :return: The generated sample.
+        :rtype: torch.Tensor
+        """
         # sampling
         return self.generator(x)

     def _train_generator(self, parameters, snapshots):
         """
-        Private method to train the generator network.
+        Train the generator model.
+
+        :param torch.Tensor parameters: The input tensor.
+        :param torch.Tensor snapshots: The target tensor.
+        :return: The residual loss and the generator loss.
+        :rtype: tuple(torch.Tensor, torch.Tensor)
         """
         optimizer = self.optimizer_generator
         optimizer.zero_grad()
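
For intuition, a hedged sketch of what a GAROM-style generator update can look like (standalone torch; the exact loss composition, the discriminator conditioning, and the names `train_generator_step` and `eta` are assumptions, not this file's code):

    import torch

    def train_generator_step(generator, discriminator, optimizer, loss_fn,
                             parameters, snapshots, eta=1.0,
                             use_regularizer=False):
        optimizer.zero_grad()
        generated = generator(parameters)
        # residual (data) loss between generated samples and true snapshots
        r_loss = loss_fn(generated, snapshots)
        # adversarial term: the autoencoder-like discriminator should
        # reconstruct the generated sample well
        d_fake = discriminator(generated)
        g_loss = loss_fn(d_fake, generated)
        if use_regularizer:
            g_loss = g_loss + eta * r_loss  # optional regularization term
        g_loss.backward()
        optimizer.step()
        return r_loss, g_loss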
@@ -170,16 +165,13 @@ class GAROM(MultiSolverInterface):
     def on_train_batch_end(self, outputs, batch, batch_idx):
         """
-        This method is called at the end of each training batch, and ovverides
-        the PytorchLightining implementation for logging the checkpoints.
+        This method is called at the end of each training batch and overrides
+        the PyTorch Lightning implementation to log checkpoints.

-        :param torch.Tensor outputs: The output from the model for the
-            current batch.
-        :param tuple batch: The current batch of data.
+        :param torch.Tensor outputs: The ``model``'s output for the current
+            batch.
+        :param dict batch: The current batch of data.
         :param int batch_idx: The index of the current batch.
-        :return: Whatever is returned by the parent
-            method ``on_train_batch_end``.
-        :rtype: Any
         """
         # increase by one the counter of optimization to save loggers
         (
@@ -190,7 +182,12 @@ class GAROM(MultiSolverInterface):
     def _train_discriminator(self, parameters, snapshots):
         """
-        Private method to train the discriminator network.
+        Train the discriminator model.
+
+        :param torch.Tensor parameters: The input tensor.
+        :param torch.Tensor snapshots: The target tensor.
+        :return: The discriminator losses on the dataset samples and on the
+            generated samples.
+        :rtype: tuple(torch.Tensor, torch.Tensor)
         """
         optimizer = self.optimizer_discriminator
         optimizer.zero_grad()
@@ -215,8 +212,15 @@ class GAROM(MultiSolverInterface):
     def _update_weights(self, d_loss_real, d_loss_fake):
         """
-        Private method to Update the weights of the generator and discriminator
-        networks.
+        Update the weights of the generator and discriminator models.
+
+        :param torch.Tensor d_loss_real: The discriminator loss computed on
+            dataset samples.
+        :param torch.Tensor d_loss_fake: The discriminator loss computed on
+            generated samples.
+        :return: The difference between the loss computed on the dataset
+            samples and the loss computed on the generated samples.
+        :rtype: torch.Tensor
         """
         diff = torch.mean(self.gamma * d_loss_real - d_loss_fake)
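
The line above is the heart of the control-theory optimization mentioned in the ``lambda_k`` docs: a BEGAN-like equilibrium scheme that balances the two discriminator losses. A hedged standalone sketch of how such an update typically proceeds (function and variable names are assumptions):

    import torch

    def update_weights(k, gamma, lambda_k, d_loss_real, d_loss_fake):
        # proportional control: drive gamma * E[d_loss_real] towards
        # E[d_loss_fake] by adjusting the balancing coefficient k
        diff = torch.mean(gamma * d_loss_real - d_loss_fake)
        k = min(max(k + lambda_k * diff.item(), 0.0), 1.0)  # keep k in [0, 1]
        return k, diff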
@@ -227,11 +231,11 @@ class GAROM(MultiSolverInterface):
         return diff

     def optimization_cycle(self, batch):
-        """GAROM solver training step.
+        """
+        The optimization cycle for the GAROM solver.

-        :param batch: The batch element in the dataloader.
-        :type batch: tuple
-        :return: The sum of the loss functions.
+        :param tuple batch: The batch element in the dataloader.
+        :return: The loss of the optimization cycle.
         :rtype: LabelTensor
         """
         condition_loss = {}
@@ -258,6 +262,13 @@ class GAROM(MultiSolverInterface):
         return condition_loss

     def validation_step(self, batch):
+        """
+        The validation step for the GAROM solver.
+
+        :param dict batch: The batch of data to use in the validation step.
+        :return: The loss of the validation step.
+        :rtype: torch.Tensor
+        """
         condition_loss = {}
         for condition_name, points in batch:
             parameters, snapshots = (
@@ -273,6 +284,13 @@ class GAROM(MultiSolverInterface):
         return loss

     def test_step(self, batch):
+        """
+        The test step for the GAROM solver.
+
+        :param dict batch: The batch of data to use in the test step.
+        :return: The loss of the test step.
+        :rtype: torch.Tensor
+        """
         condition_loss = {}
         for condition_name, points in batch:
             parameters, snapshots = (
@@ -289,30 +307,60 @@ class GAROM(MultiSolverInterface):
     @property
     def generator(self):
-        """TODO"""
+        """
+        The generator model.
+
+        :return: The generator model.
+        :rtype: torch.nn.Module
+        """
         return self.models[0]

     @property
     def discriminator(self):
-        """TODO"""
+        """
+        The discriminator model.
+
+        :return: The discriminator model.
+        :rtype: torch.nn.Module
+        """
         return self.models[1]

     @property
     def optimizer_generator(self):
-        """TODO"""
+        """
+        The optimizer for the generator.
+
+        :return: The optimizer for the generator.
+        :rtype: Optimizer
+        """
         return self.optimizers[0].instance

     @property
     def optimizer_discriminator(self):
-        """TODO"""
+        """
+        The optimizer for the discriminator.
+
+        :return: The optimizer for the discriminator.
+        :rtype: Optimizer
+        """
         return self.optimizers[1].instance

     @property
     def scheduler_generator(self):
-        """TODO"""
+        """
+        The scheduler for the generator.
+
+        :return: The scheduler for the generator.
+        :rtype: Scheduler
+        """
         return self.schedulers[0].instance

     @property
     def scheduler_discriminator(self):
-        """TODO"""
+        """
+        The scheduler for the discriminator.
+
+        :return: The scheduler for the discriminator.
+        :rtype: Scheduler
+        """
         return self.schedulers[1].instance
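
All six properties follow the same pattern: the models live in a list (generator first, discriminator second), while optimizers and schedulers are stored in wrappers that expose the raw object through ``.instance``. A minimal runnable stand-in illustrating that pattern (the wrapper class here is an assumption for illustration, not PINA's actual implementation):

    import torch

    class OptimizerWrapper:
        # minimal stand-in for the library's Optimizer wrapper implied by
        # `self.optimizers[0].instance` above
        def __init__(self, instance):
            self.instance = instance

    net = torch.nn.Linear(2, 2)
    optimizers = [OptimizerWrapper(torch.optim.Adam(net.parameters()))]
    print(optimizers[0].instance)  # the raw torch.optim.Adam object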