minor updates in garom.py
* Removing `retain_graph` in backward for discriminator * Fixing issues with different precision training for Lightning
This commit is contained in:
committed by
Nicola Demo
parent
48b2e339b5
commit
56bd0dac7d
@@ -131,6 +131,8 @@ class GAROM(SolverInterface):
|
||||
self.gamma = gamma
|
||||
self.lambda_k = lambda_k
|
||||
self.regularizer = float(regularizer)
|
||||
self._generator = self.models[0]
|
||||
self._discriminator = self.models[1]
|
||||
|
||||
def forward(self, x, mc_steps=20, variance=False):
|
||||
"""
|
||||
@@ -215,7 +217,7 @@ class GAROM(SolverInterface):
|
||||
d_loss = d_loss_real - self.k * d_loss_fake
|
||||
|
||||
# backward step
|
||||
d_loss.backward(retain_graph=True)
|
||||
d_loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
return d_loss_real, d_loss_fake, d_loss
|
||||
@@ -251,7 +253,7 @@ class GAROM(SolverInterface):
|
||||
|
||||
condition_name = dataloader.condition_names[condition_id]
|
||||
condition = self.problem.conditions[condition_name]
|
||||
pts = batch['pts']
|
||||
pts = batch['pts'].detach()
|
||||
out = batch['output']
|
||||
|
||||
if condition_name not in self.problem.conditions:
|
||||
@@ -282,11 +284,11 @@ class GAROM(SolverInterface):
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return self.models[0]
|
||||
return self._generator
|
||||
|
||||
@property
|
||||
def discriminator(self):
|
||||
return self.models[1]
|
||||
return self._discriminator
|
||||
|
||||
@property
|
||||
def optimizer_generator(self):
|
||||
|
||||
Reference in New Issue
Block a user