diff --git a/pina/solvers/garom.py b/pina/solvers/garom.py index 4e41af8..433e2bc 100644 --- a/pina/solvers/garom.py +++ b/pina/solvers/garom.py @@ -131,6 +131,8 @@ class GAROM(SolverInterface): self.gamma = gamma self.lambda_k = lambda_k self.regularizer = float(regularizer) + self._generator = self.models[0] + self._discriminator = self.models[1] def forward(self, x, mc_steps=20, variance=False): """ @@ -215,7 +217,7 @@ class GAROM(SolverInterface): d_loss = d_loss_real - self.k * d_loss_fake # backward step - d_loss.backward(retain_graph=True) + d_loss.backward() optimizer.step() return d_loss_real, d_loss_fake, d_loss @@ -251,7 +253,7 @@ class GAROM(SolverInterface): condition_name = dataloader.condition_names[condition_id] condition = self.problem.conditions[condition_name] - pts = batch['pts'] + pts = batch['pts'].detach() out = batch['output'] if condition_name not in self.problem.conditions: @@ -282,11 +284,11 @@ class GAROM(SolverInterface): @property def generator(self): - return self.models[0] + return self._generator @property def discriminator(self): - return self.models[1] + return self._discriminator @property def optimizer_generator(self):