minor updates in garom.py

* Removing `retain_graph` in backward for discriminator
* Fixing issues with different precision training for Lightining
This commit is contained in:
Dario Coscia
2023-11-17 15:44:47 +01:00
committed by Nicola Demo
parent 48b2e339b5
commit 56bd0dac7d

View File

@@ -131,6 +131,8 @@ class GAROM(SolverInterface):
self.gamma = gamma
self.lambda_k = lambda_k
self.regularizer = float(regularizer)
self._generator = self.models[0]
self._discriminator = self.models[1]
def forward(self, x, mc_steps=20, variance=False):
"""
@@ -215,7 +217,7 @@ class GAROM(SolverInterface):
d_loss = d_loss_real - self.k * d_loss_fake
# backward step
d_loss.backward(retain_graph=True)
d_loss.backward()
optimizer.step()
return d_loss_real, d_loss_fake, d_loss
@@ -251,7 +253,7 @@ class GAROM(SolverInterface):
condition_name = dataloader.condition_names[condition_id]
condition = self.problem.conditions[condition_name]
pts = batch['pts']
pts = batch['pts'].detach()
out = batch['output']
if condition_name not in self.problem.conditions:
@@ -282,11 +284,11 @@ class GAROM(SolverInterface):
@property
def generator(self):
return self.models[0]
return self._generator
@property
def discriminator(self):
return self.models[1]
return self._discriminator
@property
def optimizer_generator(self):