fix device problem of residual weights
This commit is contained in:
committed by
Dario Coscia
parent
4ad939fdb9
commit
419ac7ffbb
@@ -140,6 +140,21 @@ class RBAPINN(PINN):
|
|||||||
# Set the loss function to return non-aggregated losses
|
# Set the loss function to return non-aggregated losses
|
||||||
self._loss_fn = type(self._loss_fn)(reduction="none")
|
self._loss_fn = type(self._loss_fn)(reduction="none")
|
||||||
|
|
||||||
|
def on_train_start(self):
|
||||||
|
"""
|
||||||
|
Ensure that all residual weight buffers registered during initialization
|
||||||
|
are moved to the correct computation device.
|
||||||
|
"""
|
||||||
|
# Move all weight buffers to the correct device
|
||||||
|
for cond in self.problem.input_pts:
|
||||||
|
|
||||||
|
# Get the buffer for the current condition
|
||||||
|
weight_buf = getattr(self, f"weight_{cond}")
|
||||||
|
|
||||||
|
# Move the buffer to the correct device
|
||||||
|
weight_buf.data = weight_buf.data.to(self.device)
|
||||||
|
self.weights[cond] = weight_buf
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx, **kwargs):
|
def training_step(self, batch, batch_idx, **kwargs):
|
||||||
"""
|
"""
|
||||||
Solver training step. It computes the optimization cycle and aggregates
|
Solver training step. It computes the optimization cycle and aggregates
|
||||||
@@ -235,7 +250,7 @@ class RBAPINN(PINN):
|
|||||||
idx = torch.arange(
|
idx = torch.arange(
|
||||||
batch_idx * len_res,
|
batch_idx * len_res,
|
||||||
(batch_idx + 1) * len_res,
|
(batch_idx + 1) * len_res,
|
||||||
device=res.device,
|
device=self.weights[cond].device,
|
||||||
) % len(self.problem.input_pts[cond])
|
) % len(self.problem.input_pts[cond])
|
||||||
|
|
||||||
losses[cond] = self._apply_reduction(
|
losses[cond] = self._apply_reduction(
|
||||||
@@ -271,7 +286,7 @@ class RBAPINN(PINN):
|
|||||||
|
|
||||||
# Compute normalized residuals
|
# Compute normalized residuals
|
||||||
res = residuals[cond]
|
res = residuals[cond]
|
||||||
res_abs = res.abs()
|
res_abs = torch.linalg.vector_norm(res, ord=2, dim=1, keepdim=True)
|
||||||
r_norm = (self.eta * res_abs) / (res_abs.max() + 1e-12)
|
r_norm = (self.eta * res_abs) / (res_abs.max() + 1e-12)
|
||||||
|
|
||||||
# Get the correct indices for the weights. Modulus is used according
|
# Get the correct indices for the weights. Modulus is used according
|
||||||
@@ -280,7 +295,7 @@ class RBAPINN(PINN):
|
|||||||
idx = torch.arange(
|
idx = torch.arange(
|
||||||
batch_idx * len_pts,
|
batch_idx * len_pts,
|
||||||
(batch_idx + 1) * len_pts,
|
(batch_idx + 1) * len_pts,
|
||||||
device=res.device,
|
device=self.weights[cond].device,
|
||||||
) % len(self.problem.input_pts[cond])
|
) % len(self.problem.input_pts[cond])
|
||||||
|
|
||||||
# Update weights
|
# Update weights
|
||||||
|
|||||||
Reference in New Issue
Block a user