minor updates in garom.py
* Removing `retain_graph` in backward for discriminator * Fixing issues with different precision training for Lightining
This commit is contained in:
committed by
Nicola Demo
parent
48b2e339b5
commit
56bd0dac7d
@@ -131,6 +131,8 @@ class GAROM(SolverInterface):
|
|||||||
self.gamma = gamma
|
self.gamma = gamma
|
||||||
self.lambda_k = lambda_k
|
self.lambda_k = lambda_k
|
||||||
self.regularizer = float(regularizer)
|
self.regularizer = float(regularizer)
|
||||||
|
self._generator = self.models[0]
|
||||||
|
self._discriminator = self.models[1]
|
||||||
|
|
||||||
def forward(self, x, mc_steps=20, variance=False):
|
def forward(self, x, mc_steps=20, variance=False):
|
||||||
"""
|
"""
|
||||||
@@ -215,7 +217,7 @@ class GAROM(SolverInterface):
|
|||||||
d_loss = d_loss_real - self.k * d_loss_fake
|
d_loss = d_loss_real - self.k * d_loss_fake
|
||||||
|
|
||||||
# backward step
|
# backward step
|
||||||
d_loss.backward(retain_graph=True)
|
d_loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
return d_loss_real, d_loss_fake, d_loss
|
return d_loss_real, d_loss_fake, d_loss
|
||||||
@@ -251,7 +253,7 @@ class GAROM(SolverInterface):
|
|||||||
|
|
||||||
condition_name = dataloader.condition_names[condition_id]
|
condition_name = dataloader.condition_names[condition_id]
|
||||||
condition = self.problem.conditions[condition_name]
|
condition = self.problem.conditions[condition_name]
|
||||||
pts = batch['pts']
|
pts = batch['pts'].detach()
|
||||||
out = batch['output']
|
out = batch['output']
|
||||||
|
|
||||||
if condition_name not in self.problem.conditions:
|
if condition_name not in self.problem.conditions:
|
||||||
@@ -282,11 +284,11 @@ class GAROM(SolverInterface):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def generator(self):
|
def generator(self):
|
||||||
return self.models[0]
|
return self._generator
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def discriminator(self):
|
def discriminator(self):
|
||||||
return self.models[1]
|
return self._discriminator
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def optimizer_generator(self):
|
def optimizer_generator(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user