From 92104a6b06a207806996ac0e56cc68049d8864fd Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Fri, 19 Dec 2025 15:50:47 +0100 Subject: [PATCH] new plotting strategy --- ThermalSolver/autoregressive_module.py | 82 +++++++++------------- ThermalSolver/graph_datamodule.py | 4 +- ThermalSolver/graph_datamodule_unsteady.py | 8 ++- 3 files changed, 45 insertions(+), 49 deletions(-) diff --git a/ThermalSolver/autoregressive_module.py b/ThermalSolver/autoregressive_module.py index 2247bd1..73b260f 100644 --- a/ThermalSolver/autoregressive_module.py +++ b/ThermalSolver/autoregressive_module.py @@ -15,7 +15,7 @@ def import_class(class_path: str): return cls -def _plot_mesh(pos_, y_, y_pred_, y_true_, batch, i, batch_idx): +def _plot_mesh(pos_, y_, y_pred_, y_true_, batch, cells, i, batch_idx): # print(pos_.shape, y_.shape, y_pred_.shape, y_true_.shape) for j in [0]: idx = (batch == j).nonzero(as_tuple=True)[0] @@ -28,7 +28,8 @@ def _plot_mesh(pos_, y_, y_pred_, y_true_, batch, i, batch_idx): folder = f"{batch_idx:02d}_images" if os.path.exists(folder) is False: os.makedirs(folder) - tria = Triangulation(pos[:, 0], pos[:, 1]) + triangles = torch.vstack([cells[:, [0, 1, 2]], cells[:, [0, 2, 3]]]) + tria = Triangulation(pos[:, 0], pos[:, 1], triangles=triangles) plt.figure(figsize=(24, 6)) # plt.subplot(1, 4, 1) # plt.tricontourf(tria, y.squeeze().numpy(), levels=100) @@ -38,51 +39,36 @@ def _plot_mesh(pos_, y_, y_pred_, y_true_, batch, i, batch_idx): # plt.savefig("test_scatter_step_before.png", dpi=72) # x = z plt.subplot(1, 4, 1) - # plt.tricontourf(tria, y_pred.squeeze().numpy(), levels=100) - plt.scatter( - pos[:, 0], - pos[:, 1], - c=y_pred.squeeze().numpy(), - s=20, - cmap="viridis", - ) + plt.tricontourf(tria, y_pred.squeeze().numpy(), levels=100) + # plt.scatter(pos[:, 0], pos[:, 1], c=y_pred.squeeze().numpy(), s=20, cmap="viridis",) plt.colorbar() plt.title(f"Prediction at timestep {i:03d}") plt.subplot(1, 4, 2) - # plt.tricontourf(tria, y_true.squeeze().numpy(), levels=100) - plt.scatter( - pos[:, 0], - pos[:, 1], - c=y_true.squeeze().numpy(), - s=20, - cmap="viridis", - ) + plt.tricontourf(tria, y_true.squeeze().numpy(), levels=100) + # plt.scatter(pos[:, 0], pos[:, 1], c=y_true.squeeze().numpy(), s=20, cmap="viridis") plt.colorbar() plt.title("Ground Truth Steady State") plt.subplot(1, 4, 3) - per_element_relative_error = torch.abs(y_pred - y_true) / (y_true + 1e-6) - # plt.tricontourf(tria, per_element_relative_error.squeeze(), levels=100) - plt.scatter( - pos[:, 0], - pos[:, 1], - c=per_element_relative_error.squeeze().numpy(), - s=20, - cmap="viridis", + per_element_relative_error = torch.abs(y_pred - y_true) / ( + y_true + 1e-6 + ) + per_element_relative_error = torch.clamp( + per_element_relative_error, max=1.0, min=0.0 + ) + plt.tricontourf( + tria, + per_element_relative_error.squeeze(), + levels=100, vmin=0, vmax=1.0, ) + # plt.scatter(pos[:, 0], pos[:, 1], c=per_element_relative_error.squeeze().numpy(), s=20, cmap="viridis", vmin=0, vmax=1.0) plt.colorbar() plt.title("Relative Error") plt.subplot(1, 4, 4) absolute_error = torch.abs(y_pred - y_true) - # plt.tricontourf(tria, absolute_error.squeeze(), levels=100) - plt.scatter( - pos[:, 0], - pos[:, 1], - c=absolute_error.squeeze().numpy(), - s=20, - cmap="viridis", - ) + plt.tricontourf(tria, absolute_error.squeeze(), levels=100) + # plt.scatter(pos[:, 0], pos[:, 1], c=absolute_error.squeeze().numpy(), s=20, cmap="viridis") plt.colorbar() plt.title("Absolute Error") plt.suptitle("GNO", fontsize=16) @@ -292,15 +278,9 @@ class GraphSolver(LightningModule): sequence_length = y.size(1) y = y[:, -1, :].unsqueeze(1) _plot_mesh( - batch.pos, - x, - x, - y[:, -1, :], - batch.batch, - 0, - batch_idx - ) - for i in range(100): + batch.pos, x, x, y[:, -1, :], batch.batch, batch.cells, 0, batch_idx + ) + for i in range(200): out = self._compute_model_steps( # torch.cat([x,pos], dim=-1), x, @@ -322,12 +302,15 @@ class GraphSolver(LightningModule): out, y[:, -1, :], batch.batch, - i+1, - batch_idx + batch.cells, + i + 1, + batch_idx, ) x = out loss = self.loss(out, y[:, -1, :]) - relative_error = torch.abs(out - y[:, -1, :]) / (torch.abs(y[:, -1, :]) + 1e-6) + relative_error = torch.abs(out - y[:, -1, :]) / ( + torch.abs(y[:, -1, :]) + 1e-6 + ) mean_relative_error = relative_error.mean() all_losses.append(mean_relative_error.item()) losses.append(loss) @@ -345,7 +328,12 @@ class GraphSolver(LightningModule): def on_test_end(self): if len(self.test_losses) > 0: - _plot_losses(self.test_relative_errors, self.test_losses, self.test_relative_updates, batch_idx=0) + _plot_losses( + self.test_relative_errors, + self.test_losses, + self.test_relative_updates, + batch_idx=0, + ) def configure_optimizers(self): optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3) diff --git a/ThermalSolver/graph_datamodule.py b/ThermalSolver/graph_datamodule.py index 53b40a7..1e8a961 100644 --- a/ThermalSolver/graph_datamodule.py +++ b/ThermalSolver/graph_datamodule.py @@ -82,7 +82,9 @@ class GraphDataModule(LightningDataModule): conductivity = torch.tensor( snapshot["conductivity"], dtype=torch.float32 ) - temperature = torch.tensor(snapshot["temperature"], dtype=torch.float32)[:50] + temperature = torch.tensor( + snapshot["temperature"], dtype=torch.float32 + )[:50] pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2] diff --git a/ThermalSolver/graph_datamodule_unsteady.py b/ThermalSolver/graph_datamodule_unsteady.py index e3c69cc..1935077 100644 --- a/ThermalSolver/graph_datamodule_unsteady.py +++ b/ThermalSolver/graph_datamodule_unsteady.py @@ -117,7 +117,9 @@ class GraphDataModule(LightningDataModule): if not test: for t in range(1, temperatures.size(0)): diff = temperatures[t, :] - temperatures[t - 1, :] - norm_diff = torch.norm(diff, p=2) / torch.norm(temperatures[t - 1], p=2) + norm_diff = torch.norm(diff, p=2) / torch.norm( + temperatures[t - 1], p=2 + ) if norm_diff < self.min_normalized_diff: temperatures = temperatures[: t + 1, :] break @@ -148,6 +150,9 @@ class GraphDataModule(LightningDataModule): data = [] if test: + cells = geometry.get("cells", None) + if cells is not None: + cells = torch.tensor(cells, dtype=torch.int64) data.append( MeshData( x=temperatures[0, :].unsqueeze(-1), @@ -158,6 +163,7 @@ class GraphDataModule(LightningDataModule): edge_attr=edge_attr, boundary_mask=boundary_mask, boundary_values=boundary_values, + cells=cells, ) ) return data