diff --git a/pina/solvers/pinns/rbapinn.py b/pina/solvers/pinns/rbapinn.py index b651b0b..fd551ac 100644 --- a/pina/solvers/pinns/rbapinn.py +++ b/pina/solvers/pinns/rbapinn.py @@ -158,7 +158,11 @@ class RBAPINN(PINN): residual = self.compute_residual(samples=samples, equation=equation) cond = self.current_condition_name - r_norm = self.eta * torch.abs(residual) / (torch.max(torch.abs(residual))+1e-12) + r_norm = ( + self.eta + * torch.abs(residual) + / (torch.max(torch.abs(residual)) + 1e-12) + ) self.weights[cond] = (self.gamma * self.weights[cond] + r_norm).detach() loss_value = self._vectorial_loss(