minor updates in garom.py

* Removing `retain_graph` in backward for discriminator
* Fixing issues with different precision training for Lightining
This commit is contained in:
Dario Coscia
2023-11-17 15:44:47 +01:00
committed by Nicola Demo
parent 48b2e339b5
commit 56bd0dac7d

View File

@@ -131,6 +131,8 @@ class GAROM(SolverInterface):
self.gamma = gamma self.gamma = gamma
self.lambda_k = lambda_k self.lambda_k = lambda_k
self.regularizer = float(regularizer) self.regularizer = float(regularizer)
self._generator = self.models[0]
self._discriminator = self.models[1]
def forward(self, x, mc_steps=20, variance=False): def forward(self, x, mc_steps=20, variance=False):
""" """
@@ -215,7 +217,7 @@ class GAROM(SolverInterface):
d_loss = d_loss_real - self.k * d_loss_fake d_loss = d_loss_real - self.k * d_loss_fake
# backward step # backward step
d_loss.backward(retain_graph=True) d_loss.backward()
optimizer.step() optimizer.step()
return d_loss_real, d_loss_fake, d_loss return d_loss_real, d_loss_fake, d_loss
@@ -251,7 +253,7 @@ class GAROM(SolverInterface):
condition_name = dataloader.condition_names[condition_id] condition_name = dataloader.condition_names[condition_id]
condition = self.problem.conditions[condition_name] condition = self.problem.conditions[condition_name]
pts = batch['pts'] pts = batch['pts'].detach()
out = batch['output'] out = batch['output']
if condition_name not in self.problem.conditions: if condition_name not in self.problem.conditions:
@@ -282,11 +284,11 @@ class GAROM(SolverInterface):
@property @property
def generator(self): def generator(self):
return self.models[0] return self._generator
@property @property
def discriminator(self): def discriminator(self):
return self.models[1] return self._discriminator
@property @property
def optimizer_generator(self): def optimizer_generator(self):