From 732d48c3602234a63d7af0017d2c18f074e9e652 Mon Sep 17 00:00:00 2001 From: Filippo Olivo Date: Fri, 12 Dec 2025 10:18:16 +0100 Subject: [PATCH] new data format --- ThermalSolver/autoregressive_module.py | 128 +++++++++++++-------- ThermalSolver/graph_datamodule_unsteady.py | 73 ++++++------ ThermalSolver/model/diffusion_net.py | 4 +- 3 files changed, 118 insertions(+), 87 deletions(-) diff --git a/ThermalSolver/autoregressive_module.py b/ThermalSolver/autoregressive_module.py index 34d0073..6f85528 100644 --- a/ThermalSolver/autoregressive_module.py +++ b/ThermalSolver/autoregressive_module.py @@ -16,48 +16,79 @@ def import_class(class_path: str): def _plot_mesh(pos_, y_, y_pred_, y_true_, batch, i, batch_idx): - for j in [0, 10, 20, 30]: + # print(pos_.shape, y_.shape, y_pred_.shape, y_true_.shape) + for j in [0]: idx = (batch == j).nonzero(as_tuple=True)[0] y = y_[idx].detach().cpu() y_pred = y_pred_[idx].detach().cpu() pos = pos_[idx].detach().cpu() + # print(pos.shape, y.shape, y_pred.shape) y_true = y_true_[idx].detach().cpu() y_true = torch.clamp(y_true, min=0) folder = f"{j:02d}_images" if os.path.exists(folder) is False: os.makedirs(folder) - pos = pos.detach().cpu() tria = Triangulation(pos[:, 0], pos[:, 1]) - plt.figure(figsize=(24, 5)) - plt.subplot(1, 4, 1) - plt.tricontourf(tria, y.squeeze().numpy(), levels=100) - plt.colorbar() - plt.title("Step t-1") - plt.subplot(1, 4, 2) - plt.tricontourf(tria, y_pred.squeeze().numpy(), levels=100) + plt.figure(figsize=(18, 6)) + # plt.subplot(1, 4, 1) + # plt.tricontourf(tria, y.squeeze().numpy(), levels=100) + # plt.colorbar() + # plt.title("Step t-1") + # plt.tripcolor(tria, y_pred.squeeze().numpy() + # plt.savefig("test_scatter_step_before.png", dpi=72) + # x = z + plt.subplot(1, 3, 1) + # plt.tricontourf(tria, y_pred.squeeze().numpy(), levels=100) + plt.scatter( + pos[:, 0], + pos[:, 1], + c=y_pred.squeeze().numpy(), + s=20, + cmap="viridis", + ) plt.colorbar() plt.title("Step t Predicted") - plt.subplot(1, 4, 3) - plt.tricontourf(tria, y_true.squeeze().numpy(), levels=100) + plt.subplot(1, 3, 2) + # plt.tricontourf(tria, y_true.squeeze().numpy(), levels=100) + plt.scatter( + pos[:, 0], + pos[:, 1], + c=y_true.squeeze().numpy(), + s=20, + cmap="viridis", + ) plt.colorbar() plt.title("t True") - plt.subplot(1, 4, 4) - plt.tricontourf(tria, (y_true - y_pred).squeeze().numpy(), levels=100) + plt.subplot(1, 3, 3) + per_element_relative_error = torch.abs(y_pred - y_true) / torch.clamp( + torch.abs(y_true), min=1e-6 + ) + # plt.tricontourf(tria, per_element_relative_error.squeeze(), levels=100) + plt.scatter( + pos[:, 0], + pos[:, 1], + c=per_element_relative_error.squeeze().numpy(), + s=20, + cmap="viridis", + ) plt.colorbar() - plt.title("Error") + plt.title("Relative Error") plt.suptitle("GNO", fontsize=16) name = f"{folder}/{j:04d}_graph_iter_{i:04d}.png" plt.savefig(name, dpi=72) plt.close() -def _plot_losses(losses, batch_idx): +def _plot_losses(test_losses, batch_idx): folder = f"{batch_idx:02d}_images" plt.figure() - plt.plot(losses) + for i, losses in enumerate(test_losses): + plt.plot(losses) + if i == 3: + break plt.yscale("log") plt.xlabel("Iteration") - plt.ylabel("Loss") + plt.ylabel("Relative Error") plt.title("Test Loss over Iterations") plt.grid(True) file_name = f"{folder}/test_loss.png" @@ -80,6 +111,7 @@ class GraphSolver(LightningModule): # print(f"Param: {param[0]}") self.loss = loss if loss is not None else torch.nn.MSELoss() self.unrolling_steps = unrolling_steps + self.test_losses = [] def _compute_loss(self, x, y): return self.loss(x, y) @@ -149,7 +181,7 @@ class GraphSolver(LightningModule): self._log_loss(loss, batch, "train") for i, layer in enumerate(self.model.layers): self.log( - f"alpha_{i}", + f"{i:03d}_alpha", layer.alpha, prog_bar=True, on_epoch=True, @@ -205,10 +237,10 @@ class GraphSolver(LightningModule): self._log_loss(loss, batch, "val") return loss - def _check_convergence(self, y_pred, y_true, tol=1e-3): - l2_norm = torch.norm(y_pred - y_true, p=2) - y_true_norm = torch.norm(y_true, p=2) - rel_error = l2_norm / (y_true_norm + 1e-8) + def _check_convergence(self, y_new, y_old, tol=1e-3): + l2_norm = torch.norm(y_new, p=2) - torch.norm(y_old, p=2) + y_old_norm = torch.norm(y_old, p=2) + rel_error = l2_norm / (y_old_norm) return rel_error.item() < tol def test_step(self, batch: Batch, batch_idx): @@ -219,7 +251,9 @@ class GraphSolver(LightningModule): losses = [] all_losses = [] norms = [] - for i in range(self.unrolling_steps): + sequence_length = y.size(1) + y = y[:, -1, :].unsqueeze(1) + for i in range(100): out = self._compute_model_steps( # torch.cat([x,pos], dim=-1), x, @@ -231,34 +265,38 @@ class GraphSolver(LightningModule): conductivity, ) norms.append(torch.norm(out - x, p=2).item()) + converged = self._check_convergence(out, x) + if batch_idx == 0: + _plot_mesh( + batch.pos, + x, + out, + y[:, -1, :], + batch.batch, + i, + self.current_epoch, + ) x = out - loss = self.loss(out, y[:, i, :]) - all_losses.append(loss.item()) + loss = self.loss(out, y[:, -1, :]) + relative_error = torch.norm(out - y[:, -1, :], p=2) / torch.norm( + y[:, -1, :], p=2 + ) + all_losses.append(relative_error.item()) losses.append(loss) - # if ( - # batch_idx == 0 - # and self.current_epoch % 10 == 0 - # and self.current_epoch > 0 - # ): - # _plot_mesh( - # batch.pos, - # x, - # out, - # y[:, i, :], - # batch.batch, - # i, - # self.current_epoch, - # ) + if converged: + print( + f"Test step converged at iteration {i} for batch {batch_idx}" + ) + break loss = torch.stack(losses).mean() - # if ( - # batch_idx == 0 - # and self.current_epoch % 10 == 0 - # and self.current_epoch > 0 - # ): - _plot_losses(norms, self.current_epoch) + self.test_losses.append(all_losses) self._log_loss(loss, batch, "test") return loss + def on_test_end(self): + if len(self.test_losses) > 0: + _plot_losses(self.test_losses, batch_idx=0) + def configure_optimizers(self): optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3) return optimizer diff --git a/ThermalSolver/graph_datamodule_unsteady.py b/ThermalSolver/graph_datamodule_unsteady.py index 1fc94f8..4026a1d 100644 --- a/ThermalSolver/graph_datamodule_unsteady.py +++ b/ThermalSolver/graph_datamodule_unsteady.py @@ -64,28 +64,6 @@ class GraphDataModule(LightningDataModule): "test": geometry.select(range(train_len + valid_len, total_len)), } - def _compute_boundary_mask( - self, bottom_ids, right_ids, top_ids, left_ids, temperature - ): - left_ids = left_ids[~torch.isin(left_ids, bottom_ids)] - right_ids = right_ids[~torch.isin(right_ids, bottom_ids)] - left_ids = left_ids[~torch.isin(left_ids, top_ids)] - right_ids = right_ids[~torch.isin(right_ids, top_ids)] - - bottom_bc = temperature[bottom_ids].median() - bottom_bc_mask = torch.ones(len(bottom_ids)) * bottom_bc - left_bc = temperature[left_ids].median() - left_bc_mask = torch.ones(len(left_ids)) * left_bc - right_bc = temperature[right_ids].median() - right_bc_mask = torch.ones(len(right_ids)) * right_bc - - boundary_values = torch.cat( - [bottom_bc_mask, right_bc_mask, left_bc_mask], dim=0 - ) - boundary_mask = torch.cat([bottom_ids, right_ids, left_ids], dim=0) - - return boundary_mask, boundary_values - def _build_dataset( self, snapshot: dict, @@ -96,25 +74,22 @@ class GraphDataModule(LightningDataModule): geometry["conductivity"], dtype=torch.float32 ) temperatures = ( - torch.tensor(snapshot["temperatures"], dtype=torch.float32)[:40] + torch.tensor(snapshot["unsteady"], dtype=torch.float32) if not test - else torch.tensor(snapshot["temperatures"], dtype=torch.float32)[ - : self.unrolling_steps + 1 - ] + else torch.stack( + [ + torch.tensor(snapshot["unsteady"], dtype=torch.float32)[ + 0, ... + ], + torch.tensor(snapshot["steady"], dtype=torch.float32), + ], + dim=0, + ) ) - times = torch.tensor(snapshot["times"], dtype=torch.float32) + print(temperatures.shape) pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2] - bottom_ids = torch.tensor( - geometry["bottom_boundary_ids"], dtype=torch.long - ) - top_ids = torch.tensor(geometry["top_boundary_ids"], dtype=torch.long) - left_ids = torch.tensor(geometry["left_boundary_ids"], dtype=torch.long) - right_ids = torch.tensor( - geometry["right_boundary_ids"], dtype=torch.long - ) - if self.build_radial_graph: raise NotImplementedError( "Radial graph building not implemented yet." @@ -125,17 +100,37 @@ class GraphDataModule(LightningDataModule): ).T edge_index = to_undirected(edge_index, num_nodes=pos.size(0)) - boundary_mask, boundary_values = self._compute_boundary_mask( - bottom_ids, right_ids, top_ids, left_ids, temperatures[0, :] + boundary_mask = torch.tensor( + geometry["constraints_mask"], dtype=torch.int64 ) + boundary_values = torch.tensor( + geometry["constraints_values"], dtype=torch.float32 + ) + edge_attr = torch.norm(pos[edge_index[0]] - pos[edge_index[1]], dim=1) if self.remove_boundary_edges: boundary_idx = torch.unique(boundary_mask) edge_index_mask = ~torch.isin(edge_index[1], boundary_idx) edge_index = edge_index[:, edge_index_mask] edge_attr = edge_attr[edge_index_mask] - n_data = temperatures.size(0) - self.unrolling_steps + + n_data = max(temperatures.size(0) - self.unrolling_steps, 1) data = [] + + if test: + data.append( + MeshData( + x=temperatures[0, :].unsqueeze(-1), + y=temperatures[1:2, :].unsqueeze(-1).permute(1, 0, 2), + c=conductivity.unsqueeze(-1), + edge_index=edge_index, + pos=pos, + edge_attr=edge_attr, + boundary_mask=boundary_mask, + boundary_values=boundary_values, + ) + ) + return data for i in range(n_data): x = temperatures[i, :].unsqueeze(-1) y = ( diff --git a/ThermalSolver/model/diffusion_net.py b/ThermalSolver/model/diffusion_net.py index 4a2bb36..644b7c5 100644 --- a/ThermalSolver/model/diffusion_net.py +++ b/ThermalSolver/model/diffusion_net.py @@ -33,14 +33,12 @@ class DiffusionLayer(MessagePassing): @property def alpha(self): - return torch.clamp(self.alpha_param, min=1e-5, max=1.0) + return torch.clamp(self.alpha_param, min=1e-7, max=1.0) def forward(self, x, edge_index, edge_weight, conductivity): edge_weight = edge_weight.unsqueeze(-1) conductance = self.phys_encoder(edge_weight) net_flux = self.propagate(edge_index, x=x, conductance=conductance) - # return (1-self.alpha) * x + self.alpha * net_flux - # return net_flux + x return x + self.alpha * net_flux def message(self, x_i, x_j, conductance):