fix module and model + add curriculum callback
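The curriculum callback named in the commit message does not appear in the hunks below. As a rough sketch only: assuming the curriculum works by growing the module's `unrolling_steps` attribute as training progresses, a Lightning callback could look like the following; the class name, import path, and schedule are assumptions, not code from this commit.

    # Hypothetical sketch of the curriculum callback named in the commit
    # message; the class name, import path, and schedule are assumptions.
    from lightning.pytorch import Callback, LightningModule, Trainer

    class CurriculumCallback(Callback):
        """Grow the number of unrolled solver steps as training progresses."""

        def __init__(self, start_steps=1, max_steps=10, every_n_epochs=5):
            self.start_steps = start_steps
            self.max_steps = max_steps
            self.every_n_epochs = every_n_epochs

        def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
            # One extra unrolling step every `every_n_epochs`, capped at max_steps.
            pl_module.unrolling_steps = min(
                self.start_steps + trainer.current_epoch // self.every_n_epochs,
                self.max_steps,
            )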
@@ -116,13 +116,12 @@ class GraphSolver(LightningModule):
         return out

     def _preprocess_batch(self, batch: Batch):
-        x, y, c, edge_index, edge_attr, nodal_area = (
+        x, y, c, edge_index, edge_attr = (
             batch.x,
             batch.y,
             batch.c,
             batch.edge_index,
             batch.edge_attr,
-            batch.nodal_area,
         )
         edge_attr = 1 / edge_attr
         conductivity = self._compute_c_ij(c, edge_index)
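`_compute_c_ij` is called above but not shown in the diff. A minimal sketch of what it might do, assuming it turns the nodal conductivity field `c` into a per-edge value; the harmonic mean below is a common choice for flux continuity, not something the diff confirms.

    # Hypothetical sketch of _compute_c_ij; the harmonic mean is an assumed
    # way of combining the nodal conductivities at the two ends of each edge.
    import torch

    def _compute_c_ij(self, c: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        c_i = c[edge_index[0]]  # conductivity at the source node of each edge
        c_j = c[edge_index[1]]  # conductivity at the target node of each edge
        return 2.0 * c_i * c_j / (c_i + c_j + 1e-12)  # guarded harmonic mean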
@@ -133,34 +132,7 @@ class GraphSolver(LightningModule):
         x, y, edge_index, edge_attr, conductivity = self._preprocess_batch(
             batch
         )
         # deg = self._compute_deg(edge_index, edge_attr, x.size(0))
         losses = []
         # print(x.shape, y.shape)
         # # print(torch.max(edge_index), torch.min(edge_index))
         # plt.figure()
         # plt.subplot(2,3,1)
         # plt.scatter(batch.pos[:,0].cpu(), batch.pos[:,1].cpu(), c=x.squeeze().cpu())
         # plt.subplot(2,3,2)
         # plt.scatter(batch.pos[:,0].cpu(), batch.pos[:,1].cpu(), c=y[:,0,:].squeeze().cpu())
         # plt.subplot(2,3,3)
         # plt.scatter(batch.pos[:,0].cpu(), batch.pos[:,1].cpu(), c=y[:,1,:].squeeze().cpu())
         # plt.subplot(2,3,4)
         # plt.scatter(batch.pos[:,0].cpu(), batch.pos[:,1].cpu(), c=y[:,2,:].squeeze().cpu())
         # plt.subplot(2,3,5)
         # plt.scatter(batch.pos[:,0].cpu(), batch.pos[:,1].cpu(), c=y[:,3,:].squeeze().cpu())
         # plt.subplot(2,3,6)
         # plt.scatter(batch.pos[:,0].cpu(), batch.pos[:,1].cpu(), c=y[:,4,:].squeeze().cpu())
         # plt.suptitle("Training Batch Visualization", fontsize=16)
         # plt.savefig("training_batch_visualization.png", dpi=300)
         # plt.close()
         # y = z
         pos = batch.pos
         boundary_mask = batch.boundary_mask
         boundary_values = batch.boundary_values
         # plt.scatter(pos[boundary_mask,0].cpu(), pos[boundary_mask,1].cpu(), c=boundary_values.cpu(), s=1)
         # plt.savefig("boundary_nodes.png", dpi=300)
         # y = z
         scale = 50
         for i in range(self.unrolling_steps):
             out = self._compute_model_steps(
                 x,
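`_compute_model_steps` is also not shown in the diff. A plausible sketch under two assumptions: it applies the learned model once per call, and it re-imposes the Dirichlet boundary values on the output so the unrolled iteration cannot drift at boundary nodes; the shapes of `boundary_values` are assumed to match the masked rows of the output.

    # Hypothetical sketch of _compute_model_steps; both the single model
    # application and the boundary clamp are assumptions.
    def _compute_model_steps(self, x, edge_index, edge_attr,
                             boundary_mask, boundary_values, conductivity):
        out = self.model(x, edge_index, edge_attr, conductivity)
        # Re-impose the Dirichlet data so unrolling cannot drift on the boundary.
        out = out.clone()
        out[boundary_mask] = boundary_values
        return out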
@@ -172,15 +144,26 @@ class GraphSolver(LightningModule):
                 conductivity,
             )
             x = out
             # print(out.shape, y[:, i, :].shape)
             losses.append(self.loss(out.flatten(), y[:, i, :].flatten()))
             # print(self.model.scale_edge_attr.item())

         loss = torch.stack(losses).mean()
         # for param in self.model.parameters():
         #     print(f"Param: {param.shape}, Grad: {param.grad}")
         #     print(f"Param: {param[0]}")
         self._log_loss(loss, batch, "train")
         for i, layer in enumerate(self.model.layers):
             self.log(
                 f"alpha_{i}",
                 layer.alpha,
                 prog_bar=True,
                 on_epoch=True,
                 on_step=False,
                 batch_size=int(batch.num_graphs),
             )
         self.log(
             "dt",
             self.model.dt,
             prog_bar=True,
             on_epoch=True,
             on_step=False,
             batch_size=int(batch.num_graphs),
         )
         return loss

     def validation_step(self, batch: Batch, batch_idx):
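`_log_loss` is used for the train, val, and test stages but defined outside these hunks. A minimal sketch, assuming it simply wraps `self.log` with a stage-prefixed key in the same style as the `alpha_{i}` and `dt` logging above; the exact key format is an assumption.

    # Hypothetical sketch of _log_loss; the "{stage}_loss" key is an
    # assumption, styled after the alpha/dt logging in training_step.
    def _log_loss(self, loss, batch, stage: str) -> None:
        self.log(
            f"{stage}_loss",
            loss,
            prog_bar=True,
            on_epoch=True,
            on_step=False,
            batch_size=int(batch.num_graphs),
        )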
@@ -222,8 +205,59 @@ class GraphSolver(LightningModule):
         self._log_loss(loss, batch, "val")
         return loss

     def _check_convergence(self, y_pred, y_true, tol=1e-3):
         l2_norm = torch.norm(y_pred - y_true, p=2)
         y_true_norm = torch.norm(y_true, p=2)
         rel_error = l2_norm / (y_true_norm + 1e-8)
         return rel_error.item() < tol
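For reference, here is how `_check_convergence` could gate the unrolled loop, stopping once the relative L2 error drops below tolerance. This early exit is an illustration only; the `test_step` below always runs the full `unrolling_steps`.

    # Illustrative use of _check_convergence; this early exit is an
    # assumption and is not what test_step below actually does.
    for i in range(self.unrolling_steps):
        out = self._compute_model_steps(...)  # arguments elided
        if self._check_convergence(out, y[:, i, :], tol=1e-3):
            break  # relative L2 error under tolerance; stop unrolling early
        x = out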
     def test_step(self, batch: Batch, batch_idx):
         x, y, edge_index, edge_attr, conductivity = self._preprocess_batch(
             batch
         )
         # deg = self._compute_deg(edge_index, edge_attr, x.size(0))
         losses = []
         all_losses = []
         norms = []
         for i in range(self.unrolling_steps):
             out = self._compute_model_steps(
                 # torch.cat([x,pos], dim=-1),
                 x,
                 edge_index,
                 edge_attr,
                 # deg,
                 batch.boundary_mask,
                 batch.boundary_values,
                 conductivity,
             )
             norms.append(torch.norm(out - x, p=2).item())
             x = out
             loss = self.loss(out, y[:, i, :])
             all_losses.append(loss.item())
             losses.append(loss)
             # if (
             #     batch_idx == 0
             #     and self.current_epoch % 10 == 0
             #     and self.current_epoch > 0
             # ):
             #     _plot_mesh(
             #         batch.pos,
             #         x,
             #         out,
             #         y[:, i, :],
             #         batch.batch,
             #         i,
             #         self.current_epoch,
             #     )
         loss = torch.stack(losses).mean()
         # if (
         #     batch_idx == 0
         #     and self.current_epoch % 10 == 0
         #     and self.current_epoch > 0
         # ):
         _plot_losses(norms, self.current_epoch)
         self._log_loss(loss, batch, "test")
         return loss
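`_plot_losses` is called in `test_step` above but defined elsewhere. A hypothetical sketch, assuming it plots the per-step update norms on a log scale and writes a PNG; the log scale, labels, and output file name are assumptions.

    # Hypothetical sketch of the _plot_losses helper called in test_step;
    # the log scale, labels, and output file name are assumptions.
    import matplotlib.pyplot as plt

    def _plot_losses(norms, epoch):
        plt.figure()
        plt.semilogy(norms)  # per-step update norms ||x_{k+1} - x_k||_2
        plt.xlabel("unrolling step")
        plt.ylabel("update L2 norm")
        plt.title(f"Update norms at epoch {epoch}")
        plt.savefig(f"update_norms_epoch_{epoch}.png", dpi=300)
        plt.close()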
     def configure_optimizers(self):
         optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)