fix device problem of residual weights

This commit is contained in:
Giovanni Canali
2025-07-28 09:51:38 +02:00
committed by Dario Coscia
parent 4ad939fdb9
commit 419ac7ffbb

View File

@@ -140,6 +140,21 @@ class RBAPINN(PINN):
# Set the loss function to return non-aggregated losses
self._loss_fn = type(self._loss_fn)(reduction="none")
def on_train_start(self):
"""
Ensure that all residual weight buffers registered during initialization
are moved to the correct computation device.
"""
# Move all weight buffers to the correct device
for cond in self.problem.input_pts:
# Get the buffer for the current condition
weight_buf = getattr(self, f"weight_{cond}")
# Move the buffer to the correct device
weight_buf.data = weight_buf.data.to(self.device)
self.weights[cond] = weight_buf
def training_step(self, batch, batch_idx, **kwargs):
"""
Solver training step. It computes the optimization cycle and aggregates
@@ -235,7 +250,7 @@ class RBAPINN(PINN):
idx = torch.arange(
batch_idx * len_res,
(batch_idx + 1) * len_res,
device=res.device,
device=self.weights[cond].device,
) % len(self.problem.input_pts[cond])
losses[cond] = self._apply_reduction(
@@ -271,7 +286,7 @@ class RBAPINN(PINN):
# Compute normalized residuals
res = residuals[cond]
res_abs = res.abs()
res_abs = torch.linalg.vector_norm(res, ord=2, dim=1, keepdim=True)
r_norm = (self.eta * res_abs) / (res_abs.max() + 1e-12)
# Get the correct indices for the weights. Modulus is used according
@@ -280,7 +295,7 @@ class RBAPINN(PINN):
idx = torch.arange(
batch_idx * len_pts,
(batch_idx + 1) * len_pts,
device=res.device,
device=self.weights[cond].device,
) % len(self.problem.input_pts[cond])
# Update weights