diff --git a/pina/solvers/pinns/rbapinn.py b/pina/solvers/pinns/rbapinn.py index 770b0a7..527f2f4 100644 --- a/pina/solvers/pinns/rbapinn.py +++ b/pina/solvers/pinns/rbapinn.py @@ -117,7 +117,7 @@ class RBAPINN(PINN): # initialize weights self.weights = {} for condition_name in problem.conditions: - self.weights[condition_name] = 0 + self.weights[condition_name] = 1 # define vectorial loss self._vectorial_loss = deepcopy(loss) @@ -158,7 +158,7 @@ class RBAPINN(PINN): residual = self.compute_residual(samples=samples, equation=equation) cond = self.current_condition_name - r_norm = self.eta * torch.abs(residual) / torch.max(torch.abs(residual)) + r_norm = self.eta * torch.abs(residual) / (torch.max(torch.abs(residual))+1e-12) self.weights[cond] = (self.gamma * self.weights[cond] + r_norm).detach() loss_value = self._vectorial_loss(