fix device problem of residual weights
This commit is contained in:
committed by
Dario Coscia
parent
4ad939fdb9
commit
419ac7ffbb
@@ -140,6 +140,21 @@ class RBAPINN(PINN):
|
||||
# Set the loss function to return non-aggregated losses
|
||||
self._loss_fn = type(self._loss_fn)(reduction="none")
|
||||
|
||||
def on_train_start(self):
|
||||
"""
|
||||
Ensure that all residual weight buffers registered during initialization
|
||||
are moved to the correct computation device.
|
||||
"""
|
||||
# Move all weight buffers to the correct device
|
||||
for cond in self.problem.input_pts:
|
||||
|
||||
# Get the buffer for the current condition
|
||||
weight_buf = getattr(self, f"weight_{cond}")
|
||||
|
||||
# Move the buffer to the correct device
|
||||
weight_buf.data = weight_buf.data.to(self.device)
|
||||
self.weights[cond] = weight_buf
|
||||
|
||||
def training_step(self, batch, batch_idx, **kwargs):
|
||||
"""
|
||||
Solver training step. It computes the optimization cycle and aggregates
|
||||
@@ -235,7 +250,7 @@ class RBAPINN(PINN):
|
||||
idx = torch.arange(
|
||||
batch_idx * len_res,
|
||||
(batch_idx + 1) * len_res,
|
||||
device=res.device,
|
||||
device=self.weights[cond].device,
|
||||
) % len(self.problem.input_pts[cond])
|
||||
|
||||
losses[cond] = self._apply_reduction(
|
||||
@@ -271,7 +286,7 @@ class RBAPINN(PINN):
|
||||
|
||||
# Compute normalized residuals
|
||||
res = residuals[cond]
|
||||
res_abs = res.abs()
|
||||
res_abs = torch.linalg.vector_norm(res, ord=2, dim=1, keepdim=True)
|
||||
r_norm = (self.eta * res_abs) / (res_abs.max() + 1e-12)
|
||||
|
||||
# Get the correct indices for the weights. Modulus is used according
|
||||
@@ -280,7 +295,7 @@ class RBAPINN(PINN):
|
||||
idx = torch.arange(
|
||||
batch_idx * len_pts,
|
||||
(batch_idx + 1) * len_pts,
|
||||
device=res.device,
|
||||
device=self.weights[cond].device,
|
||||
) % len(self.problem.input_pts[cond])
|
||||
|
||||
# Update weights
|
||||
|
||||
Reference in New Issue
Block a user