Documentation for v0.1 version (#199)

* Adding equations, fixing typos
* Improve _code.rst
* Add the team .rst and restructure index.rst
* Fixing errors

---------
Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
Committed by: Nicola Demo
Parent: 3f9305d475
Commit: 8b7b61b3bd
@@ -4,7 +4,7 @@ import torch

 try:
     from torch.optim.lr_scheduler import LRScheduler  # torch >= 2.0
 except ImportError:
     from torch.optim.lr_scheduler import _LRScheduler as LRScheduler  # torch < 2.0

 from torch.optim.lr_scheduler import ConstantLR
 from .solver import SolverInterface
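Editor's note: the try/except above is the standard way to stay compatible across PyTorch versions, since release 2.0 renamed the private scheduler base class `_LRScheduler` to the public `LRScheduler`. A minimal, self-contained sketch of the same pattern:

    # Version-compatible import of the LR-scheduler base class (sketch).
    try:
        from torch.optim.lr_scheduler import LRScheduler  # torch >= 2.0
    except ImportError:
        from torch.optim.lr_scheduler import _LRScheduler as LRScheduler  # torch < 2.0

    from torch.optim.lr_scheduler import ConstantLR

    # Either alias now works for subclass checks, as check_consistency does below:
    assert issubclass(ConstantLR, LRScheduler)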
@@ -22,28 +22,36 @@ class GAROM(SolverInterface):
     .. seealso::

         **Original reference**: Coscia, D., Demo, N., & Rozza, G. (2023).
-        Generative Adversarial Reduced Order Modelling.
-        arXiv preprint arXiv:2305.15881.
+        *Generative Adversarial Reduced Order Modelling*.
+        DOI: `arXiv preprint arXiv:2305.15881.
+        <https://doi.org/10.48550/arXiv.2305.15881>`_.
     """

-    def __init__(self,
-                 problem,
-                 generator,
-                 discriminator,
-                 extra_features=None,
-                 loss = None,
-                 optimizer_generator=torch.optim.Adam,
-                 optimizer_generator_kwargs={'lr' : 0.001},
-                 optimizer_discriminator=torch.optim.Adam,
-                 optimizer_discriminator_kwargs={'lr' : 0.001},
-                 scheduler_generator=ConstantLR,
-                 scheduler_generator_kwargs={"factor": 1, "total_iters": 0},
-                 scheduler_discriminator=ConstantLR,
-                 scheduler_discriminator_kwargs={"factor": 1, "total_iters": 0},
-                 gamma = 0.3,
-                 lambda_k = 0.001,
-                 regularizer = False,
-                 ):
+    def __init__(
+        self,
+        problem,
+        generator,
+        discriminator,
+        extra_features=None,
+        loss=None,
+        optimizer_generator=torch.optim.Adam,
+        optimizer_generator_kwargs={'lr': 0.001},
+        optimizer_discriminator=torch.optim.Adam,
+        optimizer_discriminator_kwargs={'lr': 0.001},
+        scheduler_generator=ConstantLR,
+        scheduler_generator_kwargs={
+            "factor": 1,
+            "total_iters": 0
+        },
+        scheduler_discriminator=ConstantLR,
+        scheduler_discriminator_kwargs={
+            "factor": 1,
+            "total_iters": 0
+        },
+        gamma=0.3,
+        lambda_k=0.001,
+        regularizer=False,
+    ):
         """
         :param AbstractProblem problem: The formulation of the problem.
         :param torch.nn.Module generator: The neural network model to use
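Editor's note: a hypothetical usage sketch for the new signature. The networks, shapes, the `problem` object, and the import path are illustrative placeholders assumed from the v0.1 layout, not part of this diff; a real `problem` must be a data-driven AbstractProblem, per the warning below.

    import torch
    from pina.solvers import GAROM  # assumed v0.1 import path

    # Illustrative generator/discriminator; shapes are placeholders only.
    generator = torch.nn.Sequential(
        torch.nn.Linear(2, 64), torch.nn.ReLU(), torch.nn.Linear(64, 100))
    discriminator = torch.nn.Sequential(
        torch.nn.Linear(100, 64), torch.nn.ReLU(), torch.nn.Linear(64, 100))

    solver = GAROM(
        problem=problem,  # a data-driven AbstractProblem, defined elsewhere
        generator=generator,
        discriminator=discriminator,
        optimizer_generator_kwargs={'lr': 0.001},
        optimizer_discriminator_kwargs={'lr': 0.001},
        gamma=0.3,
        lambda_k=0.001,
    )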
@@ -77,11 +85,11 @@ class GAROM(SolverInterface):
             rate scheduler for the discriminator.
         :param dict scheduler_discriminator_kwargs: LR scheduler constructor keyword args.
         :param gamma: Ratio of expected loss for generator and discriminator, defaults to 0.3.
-        :type gamma: float, optional
+        :type gamma: float
         :param lambda_k: Learning rate for control theory optimization, defaults to 0.001.
-        :type lambda_k: float, optional
+        :type lambda_k: float
         :param regularizer: Regularization term in the GAROM loss, defaults to False.
-        :type regularizer: bool, optional
+        :type regularizer: bool

         .. warning::
             The algorithm works only for data-driven models. Hence, in the ``problem`` definition
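Editor's note: ``gamma`` and ``lambda_k`` control a BEGAN-style equilibrium between the two networks, which the GAROM paper adopts: a balance term ``k`` is updated by a proportional controller so the generator loss settles near ``gamma`` times the discriminator's loss on real data. A sketch of that update rule, with placeholder loss values (not GAROM's exact internals):

    import torch

    gamma, lambda_k = 0.3, 0.001
    k = 0.0  # balance term, kept in [0, 1]

    d_loss_real = torch.tensor(0.8)  # placeholder: discriminator loss on data
    g_loss = torch.tensor(0.5)       # placeholder: generator loss

    # proportional control step toward E[g_loss] = gamma * E[d_loss_real]
    diff = gamma * d_loss_real - g_loss
    k = min(max(k + lambda_k * float(diff), 0.0), 1.0)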
@@ -90,22 +98,27 @@ class GAROM(SolverInterface):
         """

         if isinstance(extra_features, dict):
-            extra_features = [extra_features['generator'], extra_features['discriminator']]
+            extra_features = [
+                extra_features['generator'], extra_features['discriminator']
+            ]

-        super().__init__(models=[generator, discriminator],
-                         problem=problem,
-                         extra_features=extra_features,
-                         optimizers=[optimizer_generator, optimizer_discriminator],
-                         optimizers_kwargs=[optimizer_generator_kwargs, optimizer_discriminator_kwargs])
+        super().__init__(
+            models=[generator, discriminator],
+            problem=problem,
+            extra_features=extra_features,
+            optimizers=[optimizer_generator, optimizer_discriminator],
+            optimizers_kwargs=[
+                optimizer_generator_kwargs, optimizer_discriminator_kwargs
+            ])

         # set automatic optimization for GANs
         self.automatic_optimization = False

         # set loss
         if loss is None:
             loss = PowerLoss(p=1)

         # check consistency
         check_consistency(scheduler_generator, LRScheduler, subclass=True)
         check_consistency(scheduler_generator_kwargs, dict)
         check_consistency(scheduler_discriminator, LRScheduler, subclass=True)
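Editor's note: ``self.automatic_optimization = False`` puts the underlying PyTorch Lightning module into manual optimization, which GAN-style solvers need so the generator and discriminator can be stepped independently inside one ``training_step``. A generic sketch of that pattern (not GAROM's exact losses; the loss helpers are hypothetical):

    import pytorch_lightning as pl

    class ManualGAN(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.automatic_optimization = False  # step optimizers by hand

        def training_step(self, batch, batch_idx):
            opt_gen, opt_disc = self.optimizers()

            # discriminator update
            opt_disc.zero_grad()
            d_loss = self._discriminator_loss(batch)  # hypothetical helper
            self.manual_backward(d_loss)
            opt_disc.step()

            # generator update
            opt_gen.zero_grad()
            g_loss = self._generator_loss(batch)  # hypothetical helper
            self.manual_backward(g_loss)
            opt_gen.step()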
@@ -134,6 +147,20 @@ class GAROM(SolverInterface):
         self.regularizer = float(regularizer)

     def forward(self, x, mc_steps=20, variance=False):
+        """
+        Forward step for GAROM solver.
+
+        :param x: The input tensor.
+        :type x: torch.Tensor
+        :param mc_steps: Number of Monte Carlo samples to approximate the
+            expected value, defaults to 20.
+        :type mc_steps: int
+        :param variance: Whether to also return the sample variance of the solution, defaults to False.
+        :type variance: bool
+        :return: The expected value of the generator distribution. If
+            ``variance=True``, the sample variance is also returned.
+        :rtype: torch.Tensor | tuple(torch.Tensor, torch.Tensor)
+        """

         # sampling
         field_sample = [self.sample(x) for _ in range(mc_steps)]
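Editor's note: the body of ``forward`` (partially shown above) is a plain Monte Carlo estimate: draw ``mc_steps`` samples from the generator and reduce them. A standalone sketch, with ``sample_fn`` standing in for the solver's ``self.sample``:

    import torch

    def mc_forward(sample_fn, x, mc_steps=20, variance=False):
        # stack mc_steps generator samples and estimate mean (and variance)
        field_sample = torch.stack([sample_fn(x) for _ in range(mc_steps)])
        mean = field_sample.mean(dim=0)
        if variance:
            return mean, field_sample.var(dim=0)
        return mean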
@@ -147,10 +174,11 @@ class GAROM(SolverInterface):
             return mean, var

         return mean

     def configure_optimizers(self):
-        """Optimizer configuration for the GAROM
-        solver.
-        """
+        """
+        Optimizer configuration for the GAROM
+        solver.
+
+        :return: The optimizers and the schedulers
+        :rtype: tuple(list, list)
+        """
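Editor's note: the documented return type, ``tuple(list, list)``, follows the Lightning convention of returning the optimizers and schedulers as two parallel lists. A standalone sketch of what such a configuration amounts to with the defaults above (not GAROM's exact code):

    import torch
    from torch.optim.lr_scheduler import ConstantLR

    def make_optimizers(generator, discriminator):
        # one optimizer/scheduler pair per network, returned as parallel lists
        opt_gen = torch.optim.Adam(generator.parameters(), lr=0.001)
        opt_disc = torch.optim.Adam(discriminator.parameters(), lr=0.001)
        sched_gen = ConstantLR(opt_gen, factor=1, total_iters=0)
        sched_disc = ConstantLR(opt_disc, factor=1, total_iters=0)
        return [opt_gen, opt_disc], [sched_gen, sched_disc]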
@@ -220,7 +248,7 @@ class GAROM(SolverInterface):
         return diff

     def training_step(self, batch, batch_idx):
-        """PINN solver training step.
+        """GAROM solver training step.

         :param batch: The batch element in the dataloader.
         :type batch: tuple
@@ -265,27 +293,27 @@ class GAROM(SolverInterface):
         self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True, on_epoch=True, on_step=False)

         return

     @property
     def generator(self):
         return self.models[0]

     @property
     def discriminator(self):
         return self.models[1]

     @property
     def optimizer_generator(self):
         return self.optimizers[0]

     @property
     def optimizer_discriminator(self):
         return self.optimizers[1]

     @property
     def scheduler_generator(self):
         return self._schedulers[0]

     @property
     def scheduler_discriminator(self):
         return self._schedulers[1]
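Editor's note: the ``stability_metric`` logged in ``training_step`` matches the BEGAN-style global convergence measure ``M = L_real + |gamma * L_real - L_gen|``; it decreases as the generator/discriminator equilibrium is approached, so it is useful for monitoring training health. A sketch of the computation, with placeholder losses:

    import torch

    gamma = 0.3
    d_loss_real = torch.tensor(0.8)  # placeholder: discriminator loss on data
    g_loss = torch.tensor(0.5)       # placeholder: generator loss

    diff = gamma * d_loss_real - g_loss
    stability_metric = float(d_loss_real + torch.abs(diff))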