Formatting
* Adding black as a dev dependency
* Formatting pina code
* Formatting tests
committed by Nicola Demo
parent 4c4482b155
commit 42ab1a666b
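The hunks below are mechanical rewrites produced by black; none of them changes behavior. As a minimal sketch of reproducing the first hunk in-process (assuming black is installed, e.g. through the new dev dependency; `format_str` and `Mode` are black's internal API and may change between releases, the stable interface being the `black` CLI):

```python
import black

# Black normalizes docstrings, stripping the padding just inside the quotes.
src = '""" Module for GAROM """\n'
print(black.format_str(src, mode=black.Mode()), end="")
# prints: """Module for GAROM"""
```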
@@ -1,4 +1,4 @@
-""" Module for GAROM """
+"""Module for GAROM"""
 
 import torch
 
@@ -86,13 +86,14 @@ class GAROM(MultiSolverInterface):
                 scheduler_generator,
                 scheduler_discriminator,
             ],
-            use_lt=False
+            use_lt=False,
         )
 
         # check consistency
-        check_consistency(loss, (LossInterface, _Loss, torch.nn.Module),
-                          subclass=False)
-        self._loss = loss
+        check_consistency(
+            loss, (LossInterface, _Loss, torch.nn.Module), subclass=False
+        )
+        self._loss = loss
 
         # set automatic optimization for GANs
         self.automatic_optimization = False
@@ -152,9 +153,7 @@ class GAROM(MultiSolverInterface):
 
         # generator loss
         r_loss = self._loss(snapshots, generated_snapshots)
-        d_fake = self.discriminator(
-            [generated_snapshots, parameters]
-        )
+        d_fake = self.discriminator([generated_snapshots, parameters])
         g_loss = (
             self._loss(d_fake, generated_snapshots) + self.regularizer * r_loss
         )
@@ -180,8 +179,7 @@ class GAROM(MultiSolverInterface):
         """
         # increase by one the counter of optimization to save loggers
         (
-            self.trainer.fit_loop.epoch_loop.manual_optimization
-            .optim_step_progress.total.completed
+            self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed
         ) += 1
 
         return super().on_train_batch_end(outputs, batch, batch_idx)
@@ -198,9 +196,7 @@ class GAROM(MultiSolverInterface):
 
         # Discriminator pass
         d_real = self.discriminator([snapshots, parameters])
-        d_fake = self.discriminator(
-            [generated_snapshots, parameters]
-        )
+        d_fake = self.discriminator([generated_snapshots, parameters])
 
         # evaluate loss
         d_loss_real = self._loss(d_real, snapshots)
@@ -236,7 +232,10 @@ class GAROM(MultiSolverInterface):
         """
         condition_loss = {}
         for condition_name, points in batch:
-            parameters, snapshots = points['input_points'], points['output_points']
+            parameters, snapshots = (
+                points["input_points"],
+                points["output_points"],
+            )
             d_loss_real, d_loss_fake, d_loss = self._train_discriminator(
                 parameters, snapshots
             )
@@ -245,51 +244,53 @@ class GAROM(MultiSolverInterface):
             condition_loss[condition_name] = r_loss
 
         # some extra logging
-        self.store_log(
-            "d_loss",
-            float(d_loss),
-            self.get_batch_size(batch)
-        )
-        self.store_log(
-            "g_loss",
-            float(g_loss),
-            self.get_batch_size(batch)
-        )
+        self.store_log("d_loss", float(d_loss), self.get_batch_size(batch))
+        self.store_log("g_loss", float(g_loss), self.get_batch_size(batch))
         self.store_log(
             "stability_metric",
             float(d_loss_real + torch.abs(diff)),
-            self.get_batch_size(batch)
+            self.get_batch_size(batch),
         )
         return condition_loss
 
     def validation_step(self, batch):
         condition_loss = {}
         for condition_name, points in batch:
-            parameters, snapshots = points['input_points'], points['output_points']
+            parameters, snapshots = (
+                points["input_points"],
+                points["output_points"],
+            )
             snapshots_gen = self.generator(parameters)
-            condition_loss[condition_name] = self._loss(snapshots, snapshots_gen)
+            condition_loss[condition_name] = self._loss(
+                snapshots, snapshots_gen
+            )
         loss = self.weighting.aggregate(condition_loss)
-        self.store_log('val_loss', loss, self.get_batch_size(batch))
+        self.store_log("val_loss", loss, self.get_batch_size(batch))
         return loss
 
     def test_step(self, batch):
         condition_loss = {}
         for condition_name, points in batch:
-            parameters, snapshots = points['input_points'], points['output_points']
+            parameters, snapshots = (
+                points["input_points"],
+                points["output_points"],
+            )
             snapshots_gen = self.generator(parameters)
-            condition_loss[condition_name] = self._loss(snapshots, snapshots_gen)
+            condition_loss[condition_name] = self._loss(
+                snapshots, snapshots_gen
+            )
         loss = self.weighting.aggregate(condition_loss)
-        self.store_log('test_loss', loss, self.get_batch_size(batch))
+        self.store_log("test_loss", loss, self.get_batch_size(batch))
         return loss
 
     @property
     def generator(self):
         return self.models[0]
 
     @property
     def discriminator(self):
         return self.models[1]
 
     @property
     def optimizer_generator(self):
         return self.optimizers[0].instance
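Two black rules account for most of the churn above: string literals are normalized to double quotes, and a collection or argument list that already ends in a trailing comma (the "magic trailing comma") is kept exploded, one element per line; the same rule is what adds the comma after `use_lt=False` once its call is split across lines. A minimal sketch checking both rules against snippets from this diff (same caveat: `format_str`/`Mode` are internal API):

```python
import black

# Quote normalization: single-quoted strings become double-quoted.
src = "parameters, snapshots = points['input_points'], points['output_points']\n"
print(black.format_str(src, mode=black.Mode()), end="")

# Magic trailing comma: the pre-existing trailing comma keeps the list
# exploded with one element per line.
src = "schedulers = [scheduler_generator, scheduler_discriminator,]\n"
print(black.format_str(src, mode=black.Mode()), end="")
# schedulers = [
#     scheduler_generator,
#     scheduler_discriminator,
# ]
```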