fix module and model + add curriculum callback

Filippo Olivo
2025-12-09 09:18:36 +01:00
parent 2935785b31
commit f2ce282a68
4 changed files with 243 additions and 111 deletions


@@ -116,13 +116,12 @@ class GraphSolver(LightningModule):
         return out

     def _preprocess_batch(self, batch: Batch):
-        x, y, c, edge_index, edge_attr, nodal_area = (
+        x, y, c, edge_index, edge_attr = (
             batch.x,
             batch.y,
             batch.c,
             batch.edge_index,
             batch.edge_attr,
-            batch.nodal_area,
         )
         edge_attr = 1 / edge_attr
         conductivity = self._compute_c_ij(c, edge_index)
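Note: with nodal_area dropped from the unpacking, _preprocess_batch reduces the edge features to reciprocal edge attributes. A minimal sketch of that weighting on toy tensors (not the repo's code; that edge_attr stores an edge length is an assumption suggested by the 1 / edge_attr line):

import torch

# Three toy nodes; edge_index has one row of sources, one of targets.
pos = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
edge_index = torch.tensor([[0, 0, 1], [1, 2, 2]])
# Edge attribute = distance between endpoints; its reciprocal weights
# short edges more heavily, mirroring edge_attr = 1 / edge_attr above.
edge_attr = (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)
weights = 1.0 / edge_attr
print(weights)  # tensor([1.0000, 0.5000, 0.4472])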
@@ -133,34 +132,7 @@ class GraphSolver(LightningModule):
         x, y, edge_index, edge_attr, conductivity = self._preprocess_batch(
             batch
         )
-        # deg = self._compute_deg(edge_index, edge_attr, x.size(0))
         losses = []
-        # print(x.shape, y.shape)
-        # # print(torch.max(edge_index), torch.min(edge_index))
-        # plt.figure()
-        # plt.subplot(2,3,1)
-        # plt.scatter(batch.pos[:,0].cpu(), batch.pos[:,1].cpu(), c=x.squeeze().cpu())
-        # plt.subplot(2,3,2)
-        # plt.scatter(batch.pos[:,0].cpu(), batch.pos[:,1].cpu(), c=y[:,0,:].squeeze().cpu())
-        # plt.subplot(2,3,3)
-        # plt.scatter(batch.pos[:,0].cpu(), batch.pos[:,1].cpu(), c=y[:,1,:].squeeze().cpu())
-        # plt.subplot(2,3,4)
-        # plt.scatter(batch.pos[:,0].cpu(), batch.pos[:,1].cpu(), c=y[:,2,:].squeeze().cpu())
-        # plt.subplot(2,3,5)
-        # plt.scatter(batch.pos[:,0].cpu(), batch.pos[:,1].cpu(), c=y[:,3,:].squeeze().cpu())
-        # plt.subplot(2,3,6)
-        # plt.scatter(batch.pos[:,0].cpu(), batch.pos[:,1].cpu(), c=y[:,4,:].squeeze().cpu())
-        # plt.suptitle("Training Batch Visualization", fontsize=16)
-        # plt.savefig("training_batch_visualization.png", dpi=300)
-        # plt.close()
-        # y = z
-        pos = batch.pos
-        boundary_mask = batch.boundary_mask
-        boundary_values = batch.boundary_values
-        # plt.scatter(pos[boundary_mask,0].cpu(), pos[boundary_mask,1].cpu(), c=boundary_values.cpu(), s=1)
-        # plt.savefig("boundary_nodes.png", dpi=300)
-        # y = z
-        scale = 50
         for i in range(self.unrolling_steps):
             out = self._compute_model_steps(
                 x,
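Note: training_step now runs a plain unrolled rollout: each model output is fed back as the next input, and one loss term is collected per step. A self-contained sketch of the pattern (hypothetical step callable; shapes assume y holds one target per unrolling step):

import torch

def unrolled_loss(step, x, y, unrolling_steps):
    # y: [num_nodes, unrolling_steps, features], one supervised target per step.
    losses = []
    for i in range(unrolling_steps):
        x = step(x)  # the prediction becomes the next step's input
        losses.append(torch.nn.functional.mse_loss(x.flatten(), y[:, i, :].flatten()))
    return torch.stack(losses).mean()  # average over the rollout

x = torch.randn(10, 1)
y = torch.randn(10, 3, 1)
print(unrolled_loss(lambda t: 0.9 * t, x, y, unrolling_steps=3))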
@@ -172,15 +144,26 @@ class GraphSolver(LightningModule):
                 conductivity,
             )
             x = out
-            # print(out.shape, y[:, i, :].shape)
             losses.append(self.loss(out.flatten(), y[:, i, :].flatten()))
-        # print(self.model.scale_edge_attr.item())
         loss = torch.stack(losses).mean()
-        # for param in self.model.parameters():
-        #     print(f"Param: {param.shape}, Grad: {param.grad}")
-        #     print(f"Param: {param[0]}")
-
         self._log_loss(loss, batch, "train")
+        for i, layer in enumerate(self.model.layers):
+            self.log(
+                f"alpha_{i}",
+                layer.alpha,
+                prog_bar=True,
+                on_epoch=True,
+                on_step=False,
+                batch_size=int(batch.num_graphs),
+            )
+        self.log(
+            "dt",
+            self.model.dt,
+            prog_bar=True,
+            on_epoch=True,
+            on_step=False,
+            batch_size=int(batch.num_graphs),
+        )
         return loss

     def validation_step(self, batch: Batch, batch_idx):
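Note: the new logging block reports each layer's alpha and the shared dt once per epoch, making the learned step sizes visible next to the loss curves. A standalone sketch of the same pattern (toy module, not this repo's model; the import path assumes Lightning 2.x, older setups use pytorch_lightning):

import torch
from lightning import LightningModule

class ToyModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.dt = torch.nn.Parameter(torch.tensor(0.1))

    def training_step(self, batch, batch_idx):
        loss = (batch * self.dt).pow(2).mean()
        # on_epoch=True averages over batches; on_step=False keeps the
        # progress bar to one value per epoch, as in the diff above.
        self.log("dt", self.dt, prog_bar=True, on_epoch=True, on_step=False)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-3)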
@@ -222,8 +205,59 @@ class GraphSolver(LightningModule):
         self._log_loss(loss, batch, "val")
         return loss

+    def _check_convergence(self, y_pred, y_true, tol=1e-3):
+        l2_norm = torch.norm(y_pred - y_true, p=2)
+        y_true_norm = torch.norm(y_true, p=2)
+        rel_error = l2_norm / (y_true_norm + 1e-8)
+        return rel_error.item() < tol
+
     def test_step(self, batch: Batch, batch_idx):
-        pass
+        x, y, edge_index, edge_attr, conductivity = self._preprocess_batch(
+            batch
+        )
+        # deg = self._compute_deg(edge_index, edge_attr, x.size(0))
+        losses = []
+        all_losses = []
+        norms = []
+        for i in range(self.unrolling_steps):
+            out = self._compute_model_steps(
+                # torch.cat([x,pos], dim=-1),
+                x,
+                edge_index,
+                edge_attr,
+                # deg,
+                batch.boundary_mask,
+                batch.boundary_values,
+                conductivity,
+            )
+            norms.append(torch.norm(out - x, p=2).item())
+            x = out
+            loss = self.loss(out, y[:, i, :])
+            all_losses.append(loss.item())
+            losses.append(loss)
+            # if (
+            #     batch_idx == 0
+            #     and self.current_epoch % 10 == 0
+            #     and self.current_epoch > 0
+            # ):
+            #     _plot_mesh(
+            #         batch.pos,
+            #         x,
+            #         out,
+            #         y[:, i, :],
+            #         batch.batch,
+            #         i,
+            #         self.current_epoch,
+            #     )
+        loss = torch.stack(losses).mean()
+        # if (
+        #     batch_idx == 0
+        #     and self.current_epoch % 10 == 0
+        #     and self.current_epoch > 0
+        # ):
+        _plot_losses(norms, self.current_epoch)
+        self._log_loss(loss, batch, "test")
+        return loss

     def configure_optimizers(self):
         optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
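Note: _check_convergence implements a relative-L2 stopping test, rel_error = ||y_pred - y_true||_2 / (||y_true||_2 + 1e-8) < tol; the epsilon guards against a zero-norm target. It is defined here but not yet called from test_step. A runnable sketch of the same check outside the module:

import torch

def check_convergence(y_pred, y_true, tol=1e-3):
    # Relative L2 error; the epsilon keeps the division finite when y_true == 0.
    rel_error = torch.norm(y_pred - y_true, p=2) / (torch.norm(y_true, p=2) + 1e-8)
    return rel_error.item() < tol

y = torch.ones(100)
print(check_convergence(y + 1e-4, y))  # True: relative error is 1e-4 < 1e-3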