🎨 Format Python code with psf/black
This commit is contained in:
@@ -2,10 +2,13 @@
|
||||
|
||||
import torch
|
||||
import sys
|
||||
|
||||
try:
|
||||
from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0
|
||||
except ImportError:
|
||||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler # torch < 2.0
|
||||
from torch.optim.lr_scheduler import (
|
||||
_LRScheduler as LRScheduler,
|
||||
) # torch < 2.0
|
||||
|
||||
from torch.optim.lr_scheduler import ConstantLR
|
||||
from .solver import SolverInterface
|
||||
@@ -18,12 +21,12 @@ class GAROM(SolverInterface):
|
||||
"""
|
||||
GAROM solver class. This class implements Generative Adversarial
|
||||
Reduced Order Model solver, using user specified ``models`` to solve
|
||||
a specific order reduction``problem``.
|
||||
a specific order reduction``problem``.
|
||||
|
||||
.. seealso::
|
||||
|
||||
**Original reference**: Coscia, D., Demo, N., & Rozza, G. (2023).
|
||||
*Generative Adversarial Reduced Order Modelling*.
|
||||
*Generative Adversarial Reduced Order Modelling*.
|
||||
DOI: `arXiv preprint arXiv:2305.15881.
|
||||
<https://doi.org/10.48550/arXiv.2305.15881>`_.
|
||||
"""
|
||||
@@ -35,19 +38,13 @@ class GAROM(SolverInterface):
|
||||
discriminator,
|
||||
loss=None,
|
||||
optimizer_generator=torch.optim.Adam,
|
||||
optimizer_generator_kwargs={'lr': 0.001},
|
||||
optimizer_generator_kwargs={"lr": 0.001},
|
||||
optimizer_discriminator=torch.optim.Adam,
|
||||
optimizer_discriminator_kwargs={'lr': 0.001},
|
||||
optimizer_discriminator_kwargs={"lr": 0.001},
|
||||
scheduler_generator=ConstantLR,
|
||||
scheduler_generator_kwargs={
|
||||
"factor": 1,
|
||||
"total_iters": 0
|
||||
},
|
||||
scheduler_generator_kwargs={"factor": 1, "total_iters": 0},
|
||||
scheduler_discriminator=ConstantLR,
|
||||
scheduler_discriminator_kwargs={
|
||||
"factor": 1,
|
||||
"total_iters": 0
|
||||
},
|
||||
scheduler_discriminator_kwargs={"factor": 1, "total_iters": 0},
|
||||
gamma=0.3,
|
||||
lambda_k=0.001,
|
||||
regularizer=False,
|
||||
@@ -95,8 +92,10 @@ class GAROM(SolverInterface):
|
||||
problem=problem,
|
||||
optimizers=[optimizer_generator, optimizer_discriminator],
|
||||
optimizers_kwargs=[
|
||||
optimizer_generator_kwargs, optimizer_discriminator_kwargs
|
||||
])
|
||||
optimizer_generator_kwargs,
|
||||
optimizer_discriminator_kwargs,
|
||||
],
|
||||
)
|
||||
|
||||
# set automatic optimization for GANs
|
||||
self.automatic_optimization = False
|
||||
@@ -118,13 +117,14 @@ class GAROM(SolverInterface):
|
||||
# assign schedulers
|
||||
self._schedulers = [
|
||||
scheduler_generator(
|
||||
self.optimizers[0], **scheduler_generator_kwargs),
|
||||
self.optimizers[0], **scheduler_generator_kwargs
|
||||
),
|
||||
scheduler_discriminator(
|
||||
self.optimizers[1],
|
||||
**scheduler_discriminator_kwargs)
|
||||
self.optimizers[1], **scheduler_discriminator_kwargs
|
||||
),
|
||||
]
|
||||
|
||||
# loss and writer
|
||||
# loss and writer
|
||||
self._loss = loss
|
||||
|
||||
# began hyperparameters
|
||||
@@ -141,7 +141,7 @@ class GAROM(SolverInterface):
|
||||
|
||||
:param x: The input tensor.
|
||||
:type x: torch.Tensor
|
||||
:param mc_steps: Number of montecarlo samples to approximate the
|
||||
:param mc_steps: Number of montecarlo samples to approximate the
|
||||
expected value, defaults to 20.
|
||||
:type mc_steps: int
|
||||
:param variance: Returining also the sample variance of the solution, defaults to False.
|
||||
@@ -189,8 +189,12 @@ class GAROM(SolverInterface):
|
||||
|
||||
# generator loss
|
||||
r_loss = self._loss(snapshots, generated_snapshots)
|
||||
d_fake = self.discriminator.forward_map([generated_snapshots, parameters])
|
||||
g_loss = self._loss(d_fake, generated_snapshots) + self.regularizer * r_loss
|
||||
d_fake = self.discriminator.forward_map(
|
||||
[generated_snapshots, parameters]
|
||||
)
|
||||
g_loss = (
|
||||
self._loss(d_fake, generated_snapshots) + self.regularizer * r_loss
|
||||
)
|
||||
|
||||
# backward step
|
||||
g_loss.backward()
|
||||
@@ -210,7 +214,9 @@ class GAROM(SolverInterface):
|
||||
|
||||
# Discriminator pass
|
||||
d_real = self.discriminator.forward_map([snapshots, parameters])
|
||||
d_fake = self.discriminator.forward_map([generated_snapshots, parameters])
|
||||
d_fake = self.discriminator.forward_map(
|
||||
[generated_snapshots, parameters]
|
||||
)
|
||||
|
||||
# evaluate loss
|
||||
d_loss_real = self._loss(d_real, snapshots)
|
||||
@@ -235,7 +241,7 @@ class GAROM(SolverInterface):
|
||||
self.k += self.lambda_k * diff.item()
|
||||
self.k = min(max(self.k, 0), 1) # Constraint to interval [0, 1]
|
||||
return diff
|
||||
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
"""GAROM solver training step.
|
||||
|
||||
@@ -248,42 +254,75 @@ class GAROM(SolverInterface):
|
||||
"""
|
||||
|
||||
dataloader = self.trainer.train_dataloader
|
||||
condition_idx = batch['condition']
|
||||
condition_idx = batch["condition"]
|
||||
|
||||
for condition_id in range(condition_idx.min(), condition_idx.max()+1):
|
||||
for condition_id in range(condition_idx.min(), condition_idx.max() + 1):
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
condition_name = dataloader.condition_names[condition_id]
|
||||
else:
|
||||
condition_name = dataloader.loaders.condition_names[condition_id]
|
||||
condition_name = dataloader.loaders.condition_names[
|
||||
condition_id
|
||||
]
|
||||
|
||||
condition = self.problem.conditions[condition_name]
|
||||
pts = batch['pts'].detach()
|
||||
out = batch['output']
|
||||
pts = batch["pts"].detach()
|
||||
out = batch["output"]
|
||||
|
||||
if condition_name not in self.problem.conditions:
|
||||
raise RuntimeError('Something wrong happened.')
|
||||
raise RuntimeError("Something wrong happened.")
|
||||
|
||||
# for data driven mode
|
||||
if not hasattr(condition, 'output_points'):
|
||||
raise NotImplementedError('GAROM works only in data-driven mode.')
|
||||
if not hasattr(condition, "output_points"):
|
||||
raise NotImplementedError(
|
||||
"GAROM works only in data-driven mode."
|
||||
)
|
||||
|
||||
# get data
|
||||
snapshots = out[condition_idx == condition_id]
|
||||
parameters = pts[condition_idx == condition_id]
|
||||
|
||||
d_loss_real, d_loss_fake, d_loss = self._train_discriminator(
|
||||
parameters, snapshots)
|
||||
parameters, snapshots
|
||||
)
|
||||
|
||||
r_loss, g_loss = self._train_generator(parameters, snapshots)
|
||||
|
||||
|
||||
diff = self._update_weights(d_loss_real, d_loss_fake)
|
||||
|
||||
# logging
|
||||
self.log('mean_loss', float(r_loss), prog_bar=True, logger=True, on_epoch=True, on_step=False)
|
||||
self.log('d_loss', float(d_loss), prog_bar=True, logger=True, on_epoch=True, on_step=False)
|
||||
self.log('g_loss', float(g_loss), prog_bar=True, logger=True, on_epoch=True, on_step=False)
|
||||
self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True, on_epoch=True, on_step=False)
|
||||
self.log(
|
||||
"mean_loss",
|
||||
float(r_loss),
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
on_epoch=True,
|
||||
on_step=False,
|
||||
)
|
||||
self.log(
|
||||
"d_loss",
|
||||
float(d_loss),
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
on_epoch=True,
|
||||
on_step=False,
|
||||
)
|
||||
self.log(
|
||||
"g_loss",
|
||||
float(g_loss),
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
on_epoch=True,
|
||||
on_step=False,
|
||||
)
|
||||
self.log(
|
||||
"stability_metric",
|
||||
float(d_loss_real + torch.abs(diff)),
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
on_epoch=True,
|
||||
on_step=False,
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
Reference in New Issue
Block a user