Formatting

* Adding black as dev dependency
* Formatting pina code
* Formatting tests
This commit is contained in:
Dario Coscia
2025-02-24 11:26:49 +01:00
committed by Nicola Demo
parent 4c4482b155
commit 42ab1a666b
77 changed files with 1170 additions and 924 deletions

View File

@@ -1,4 +1,4 @@
""" Module for GAROM """
"""Module for GAROM"""
import torch
@@ -86,13 +86,14 @@ class GAROM(MultiSolverInterface):
scheduler_generator,
scheduler_discriminator,
],
use_lt=False
use_lt=False,
)
# check consistency
check_consistency(loss, (LossInterface, _Loss, torch.nn.Module),
subclass=False)
self._loss = loss
check_consistency(
loss, (LossInterface, _Loss, torch.nn.Module), subclass=False
)
self._loss = loss
# set automatic optimization for GANs
self.automatic_optimization = False
@@ -152,9 +153,7 @@ class GAROM(MultiSolverInterface):
# generator loss
r_loss = self._loss(snapshots, generated_snapshots)
d_fake = self.discriminator(
[generated_snapshots, parameters]
)
d_fake = self.discriminator([generated_snapshots, parameters])
g_loss = (
self._loss(d_fake, generated_snapshots) + self.regularizer * r_loss
)
@@ -180,8 +179,7 @@ class GAROM(MultiSolverInterface):
"""
# increase by one the counter of optimization to save loggers
(
self.trainer.fit_loop.epoch_loop.manual_optimization
.optim_step_progress.total.completed
self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed
) += 1
return super().on_train_batch_end(outputs, batch, batch_idx)
@@ -198,9 +196,7 @@ class GAROM(MultiSolverInterface):
# Discriminator pass
d_real = self.discriminator([snapshots, parameters])
d_fake = self.discriminator(
[generated_snapshots, parameters]
)
d_fake = self.discriminator([generated_snapshots, parameters])
# evaluate loss
d_loss_real = self._loss(d_real, snapshots)
@@ -236,7 +232,10 @@ class GAROM(MultiSolverInterface):
"""
condition_loss = {}
for condition_name, points in batch:
parameters, snapshots = points['input_points'], points['output_points']
parameters, snapshots = (
points["input_points"],
points["output_points"],
)
d_loss_real, d_loss_fake, d_loss = self._train_discriminator(
parameters, snapshots
)
@@ -245,51 +244,53 @@ class GAROM(MultiSolverInterface):
condition_loss[condition_name] = r_loss
# some extra logging
self.store_log(
"d_loss",
float(d_loss),
self.get_batch_size(batch)
)
self.store_log(
"g_loss",
float(g_loss),
self.get_batch_size(batch)
)
self.store_log("d_loss", float(d_loss), self.get_batch_size(batch))
self.store_log("g_loss", float(g_loss), self.get_batch_size(batch))
self.store_log(
"stability_metric",
float(d_loss_real + torch.abs(diff)),
self.get_batch_size(batch)
self.get_batch_size(batch),
)
return condition_loss
def validation_step(self, batch):
condition_loss = {}
for condition_name, points in batch:
parameters, snapshots = points['input_points'], points['output_points']
parameters, snapshots = (
points["input_points"],
points["output_points"],
)
snapshots_gen = self.generator(parameters)
condition_loss[condition_name] = self._loss(snapshots, snapshots_gen)
condition_loss[condition_name] = self._loss(
snapshots, snapshots_gen
)
loss = self.weighting.aggregate(condition_loss)
self.store_log('val_loss', loss, self.get_batch_size(batch))
self.store_log("val_loss", loss, self.get_batch_size(batch))
return loss
def test_step(self, batch):
condition_loss = {}
for condition_name, points in batch:
parameters, snapshots = points['input_points'], points['output_points']
parameters, snapshots = (
points["input_points"],
points["output_points"],
)
snapshots_gen = self.generator(parameters)
condition_loss[condition_name] = self._loss(snapshots, snapshots_gen)
condition_loss[condition_name] = self._loss(
snapshots, snapshots_gen
)
loss = self.weighting.aggregate(condition_loss)
self.store_log('test_loss', loss, self.get_batch_size(batch))
self.store_log("test_loss", loss, self.get_batch_size(batch))
return loss
@property
def generator(self):
return self.models[0]
@property
def discriminator(self):
return self.models[1]
@property
def optimizer_generator(self):
return self.optimizers[0].instance