From 56bd0dac7d1f6f2c534dd6f93d312d1df57aba5f Mon Sep 17 00:00:00 2001
From: Dario Coscia <93731561+dario-coscia@users.noreply.github.com>
Date: Fri, 17 Nov 2023 15:44:47 +0100
Subject: [PATCH] minor updates in garom.py

* Removing `retain_graph=True` from the discriminator backward pass
* Fixing issues with different precision training in Lightning
---
 pina/solvers/garom.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/pina/solvers/garom.py b/pina/solvers/garom.py
index 4e41af8..433e2bc 100644
--- a/pina/solvers/garom.py
+++ b/pina/solvers/garom.py
@@ -131,6 +131,8 @@ class GAROM(SolverInterface):
         self.gamma = gamma
         self.lambda_k = lambda_k
         self.regularizer = float(regularizer)
+        self._generator = self.models[0]
+        self._discriminator = self.models[1]
 
     def forward(self, x, mc_steps=20, variance=False):
         """
@@ -215,7 +217,7 @@ class GAROM(SolverInterface):
         d_loss = d_loss_real - self.k * d_loss_fake
 
         # backward step
-        d_loss.backward(retain_graph=True)
+        d_loss.backward()
         optimizer.step()
 
         return d_loss_real, d_loss_fake, d_loss
@@ -251,7 +253,7 @@ class GAROM(SolverInterface):
 
             condition_name = dataloader.condition_names[condition_id]
             condition = self.problem.conditions[condition_name]
-            pts = batch['pts']
+            pts = batch['pts'].detach()
             out = batch['output']
 
             if condition_name not in self.problem.conditions:
@@ -282,11 +284,11 @@ class GAROM(SolverInterface):
 
     @property
     def generator(self):
-        return self.models[0]
+        return self._generator
 
     @property
    def discriminator(self):
-        return self.models[1]
+        return self._discriminator
 
     @property
     def optimizer_generator(self):
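
Note (not part of the patch): a minimal, self-contained sketch of why a plain
backward() can replace d_loss.backward(retain_graph=True) once the tensors fed
to the discriminator are cut off from the generator's graph, in the spirit of
the .detach() call the diff adds for batch['pts']. All names below
(toy_generator, toy_discriminator, params, snapshots, opt_d) are hypothetical
stand-ins, not the actual GAROM internals.

import torch
import torch.nn as nn

# Toy stand-ins for the generator/discriminator pair managed by the solver.
toy_generator = nn.Linear(4, 2)
toy_discriminator = nn.Linear(2, 1)
opt_d = torch.optim.Adam(toy_discriminator.parameters(), lr=1e-3)

params = torch.randn(8, 4)     # stands in for batch['pts']
snapshots = torch.randn(8, 2)  # stands in for batch['output']

# Detaching the fake sample cuts the autograd graph at the generator, so the
# discriminator loss owns its whole (small) graph and a plain backward() can
# free it safely; retain_graph=True is then unnecessary.
fake = toy_generator(params).detach()
d_loss_real = toy_discriminator(snapshots).mean()
d_loss_fake = toy_discriminator(fake).mean()
d_loss = d_loss_real - d_loss_fake  # toy objective; GAROM weights the fake term by k

opt_d.zero_grad()
d_loss.backward()  # no retain_graph needed
opt_d.step()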