From 419ac7ffbbbbacd4a4bac0ce37753d72718aa75c Mon Sep 17 00:00:00 2001 From: Giovanni Canali Date: Mon, 28 Jul 2025 09:51:38 +0200 Subject: [PATCH] fix device problem of residual weights --- .../physics_informed_solver/rba_pinn.py | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/pina/solver/physics_informed_solver/rba_pinn.py b/pina/solver/physics_informed_solver/rba_pinn.py index 831f7d4..5c8d50f 100644 --- a/pina/solver/physics_informed_solver/rba_pinn.py +++ b/pina/solver/physics_informed_solver/rba_pinn.py @@ -140,6 +140,21 @@ class RBAPINN(PINN): # Set the loss function to return non-aggregated losses self._loss_fn = type(self._loss_fn)(reduction="none") + def on_train_start(self): + """ + Ensure that all residual weight buffers registered during initialization + are moved to the correct computation device. + """ + # Move all weight buffers to the correct device + for cond in self.problem.input_pts: + + # Get the buffer for the current condition + weight_buf = getattr(self, f"weight_{cond}") + + # Move the buffer to the correct device + weight_buf.data = weight_buf.data.to(self.device) + self.weights[cond] = weight_buf + def training_step(self, batch, batch_idx, **kwargs): """ Solver training step. It computes the optimization cycle and aggregates @@ -235,7 +250,7 @@ class RBAPINN(PINN): idx = torch.arange( batch_idx * len_res, (batch_idx + 1) * len_res, - device=res.device, + device=self.weights[cond].device, ) % len(self.problem.input_pts[cond]) losses[cond] = self._apply_reduction( @@ -271,7 +286,7 @@ class RBAPINN(PINN): # Compute normalized residuals res = residuals[cond] - res_abs = res.abs() + res_abs = torch.linalg.vector_norm(res, ord=2, dim=1, keepdim=True) r_norm = (self.eta * res_abs) / (res_abs.max() + 1e-12) # Get the correct indices for the weights. Modulus is used according @@ -280,7 +295,7 @@ class RBAPINN(PINN): idx = torch.arange( batch_idx * len_pts, (batch_idx + 1) * len_pts, - device=res.device, + device=self.weights[cond].device, ) % len(self.problem.input_pts[cond]) # Update weights