diff --git a/ThermalSolver/autoregressive_module.py b/ThermalSolver/autoregressive_module.py
index f5254dd..2247bd1 100644
--- a/ThermalSolver/autoregressive_module.py
+++ b/ThermalSolver/autoregressive_module.py
@@ -17,7 +17,7 @@ def import_class(class_path: str):

 def _plot_mesh(pos_, y_, y_pred_, y_true_, batch, i, batch_idx):
     # print(pos_.shape, y_.shape, y_pred_.shape, y_true_.shape)
-    for j in [0, 5, 10, 20]:
+    for j in [0]:
         idx = (batch == j).nonzero(as_tuple=True)[0]
         y = y_[idx].detach().cpu()
         y_pred = y_pred_[idx].detach().cpu()
@@ -25,11 +25,11 @@ def _plot_mesh(pos_, y_, y_pred_, y_true_, batch, i, batch_idx):
         # print(pos.shape, y.shape, y_pred.shape)
         y_true = y_true_[idx].detach().cpu()
         y_true = torch.clamp(y_true, min=0)
-        folder = f"{j:02d}_images"
+        folder = f"{batch_idx:02d}_images"
         if os.path.exists(folder) is False:
             os.makedirs(folder)
         tria = Triangulation(pos[:, 0], pos[:, 1])
-        plt.figure(figsize=(18, 6))
+        plt.figure(figsize=(24, 6))
         # plt.subplot(1, 4, 1)
         # plt.tricontourf(tria, y.squeeze().numpy(), levels=100)
         # plt.colorbar()
@@ -37,59 +37,94 @@ def _plot_mesh(pos_, y_, y_pred_, y_true_, batch, i, batch_idx):
         # plt.tripcolor(tria, y_pred.squeeze().numpy()
         # plt.savefig("test_scatter_step_before.png", dpi=72)
         # x = z
-        plt.subplot(1, 3, 1)
-        plt.tricontourf(tria, y_pred.squeeze().numpy(), levels=100)
-        # plt.scatter(
-        #     pos[:, 0],
-        #     pos[:, 1],
-        #     c=y_pred.squeeze().numpy(),
-        #     s=20,
-        #     cmap="viridis",
-        # )
+        plt.subplot(1, 4, 1)
+        # plt.tricontourf(tria, y_pred.squeeze().numpy(), levels=100)
+        plt.scatter(
+            pos[:, 0],
+            pos[:, 1],
+            c=y_pred.squeeze().numpy(),
+            s=20,
+            cmap="viridis",
+        )
         plt.colorbar()
-        plt.title("Step t Predicted")
-        plt.subplot(1, 3, 2)
-        plt.tricontourf(tria, y_true.squeeze().numpy(), levels=100)
-        # plt.scatter(
-        #     pos[:, 0],
-        #     pos[:, 1],
-        #     c=y_true.squeeze().numpy(),
-        #     s=20,
-        #     cmap="viridis",
-        # )
+        plt.title(f"Prediction at timestep {i:03d}")
+        plt.subplot(1, 4, 2)
+        # plt.tricontourf(tria, y_true.squeeze().numpy(), levels=100)
+        plt.scatter(
+            pos[:, 0],
+            pos[:, 1],
+            c=y_true.squeeze().numpy(),
+            s=20,
+            cmap="viridis",
+        )
         plt.colorbar()
-        plt.title("t True")
-        plt.subplot(1, 3, 3)
-        per_element_relative_error = torch.abs(y_pred - y_true)
-        plt.tricontourf(tria, per_element_relative_error.squeeze(), levels=100)
-        # plt.scatter(
-        #     pos[:, 0],
-        #     pos[:, 1],
-        #     c=per_element_relative_error.squeeze().numpy(),
-        #     s=20,
-        #     cmap="viridis",
-        # )
+        plt.title("Ground Truth Steady State")
+        plt.subplot(1, 4, 3)
+        per_element_relative_error = torch.abs(y_pred - y_true) / (y_true + 1e-6)
+        # plt.tricontourf(tria, per_element_relative_error.squeeze(), levels=100)
+        plt.scatter(
+            pos[:, 0],
+            pos[:, 1],
+            c=per_element_relative_error.squeeze().numpy(),
+            s=20,
+            cmap="viridis",
+            vmin=0,
+            vmax=1.0,
+        )
         plt.colorbar()
         plt.title("Relative Error")
+        plt.subplot(1, 4, 4)
+        absolute_error = torch.abs(y_pred - y_true)
+        # plt.tricontourf(tria, absolute_error.squeeze(), levels=100)
+        plt.scatter(
+            pos[:, 0],
+            pos[:, 1],
+            c=absolute_error.squeeze().numpy(),
+            s=20,
+            cmap="viridis",
+        )
+        plt.colorbar()
+        plt.title("Absolute Error")
         plt.suptitle("GNO", fontsize=16)
         name = f"{folder}/{j:04d}_graph_iter_{i:04d}.png"
         plt.savefig(name, dpi=72)
         plt.close()


-def _plot_losses(test_losses, batch_idx):
-    folder = f"{batch_idx:02d}_images"
-    plt.figure()
+def _plot_losses(relative_errors, test_losses, relative_update, batch_idx):
+    # folder = f"{batch_idx:02d}_images"
+    plt.figure(figsize=(18, 6))
+    plt.subplot(1, 3, 1)
     for i, losses in enumerate(test_losses):
         plt.plot(losses)
         if i == 3:
             break
     plt.yscale("log")
     plt.xlabel("Iteration")
-    plt.ylabel("Relative Error")
+    plt.ylabel("Test Loss")
     plt.title("Test Loss over Iterations")
     plt.grid(True)
-    file_name = f"{folder}/test_loss.png"
+    plt.subplot(1, 3, 2)
+    for i, losses in enumerate(relative_errors):
+        plt.plot(losses)
+        if i == 3:
+            break
+    plt.yscale("log")
+    plt.xlabel("Iteration")
+    plt.ylabel("Relative Error")
+    plt.title("Relative Error over Iterations")
+    plt.grid(True)
+    plt.subplot(1, 3, 3)
+    for i, updates in enumerate(relative_update):
+        plt.plot(updates)
+        if i == 3:
+            break
+    plt.yscale("log")
+    plt.xlabel("Iteration")
+    plt.ylabel("Relative Update")
+    plt.title("Relative Update over Iterations")
+    plt.grid(True)
+    file_name = "test_errors.png"
     plt.savefig(file_name, dpi=300)
     plt.close()

@@ -110,6 +145,8 @@ class GraphSolver(LightningModule):
         self.loss = loss if loss is not None else torch.nn.MSELoss()
         self.unrolling_steps = unrolling_steps
         self.test_losses = []
+        self.test_relative_errors = []
+        self.test_relative_updates = []

     def _compute_loss(self, x, y):
         return self.loss(x, y)
@@ -164,6 +201,7 @@ class GraphSolver(LightningModule):
         )
         losses = []
         for i in range(self.unrolling_steps):
+            # print(f"Training step {i+1}/{self.unrolling_steps}")
             out = self._compute_model_steps(
                 x,
                 edge_index,
@@ -235,11 +273,11 @@ class GraphSolver(LightningModule):
         self._log_loss(loss, batch, "val")
         return loss

-    def _check_convergence(self, y_new, y_old, tol=1e-3):
-        l2_norm = torch.norm(y_new, p=2) - torch.norm(y_old, p=2)
+    def _check_convergence(self, y_new, y_old, tol=1e-4):
+        l2_norm = torch.norm(y_new - y_old, p=2)
         y_old_norm = torch.norm(y_old, p=2)
         rel_error = l2_norm / (y_old_norm)
-        return rel_error.item() < tol
+        return rel_error.item() < tol, rel_error.item()

     def test_step(self, batch: Batch, batch_idx):
         x, y, edge_index, edge_attr, conductivity = self._preprocess_batch(
@@ -249,8 +287,19 @@
         losses = []
         all_losses = []
         norms = []
+        s = []
+        relative_updates = []
         sequence_length = y.size(1)
         y = y[:, -1, :].unsqueeze(1)
+        _plot_mesh(
+            batch.pos,
+            x,
+            x,
+            y[:, -1, :],
+            batch.batch,
+            0,
+            batch_idx,
+        )
         for i in range(100):
             out = self._compute_model_steps(
                 # torch.cat([x,pos], dim=-1),
@@ -263,23 +312,24 @@
                 conductivity,
             )
             norms.append(torch.norm(out - x, p=2).item())
-            converged = self._check_convergence(out, x)
-            if batch_idx == 0:
+            converged, relative_update = self._check_convergence(out, x)
+            relative_updates.append(relative_update)
+            if batch_idx <= 4:
+                print(f"Plotting iteration {i}, norm diff: {norms[-1]}")
                 _plot_mesh(
                     batch.pos,
                     x,
                     out,
                     y[:, -1, :],
                     batch.batch,
-                    i,
-                    self.current_epoch,
+                    i + 1,
+                    batch_idx,
                 )
             x = out
             loss = self.loss(out, y[:, -1, :])
-            relative_error = torch.norm(out - y[:, -1, :], p=2) / torch.norm(
-                y[:, -1, :], p=2
-            )
-            all_losses.append(relative_error.item())
+            relative_error = torch.abs(out - y[:, -1, :]) / (torch.abs(y[:, -1, :]) + 1e-6)
+            mean_relative_error = relative_error.mean()
+            all_losses.append(mean_relative_error.item())
             losses.append(loss)
             if converged:
                 print(
@@ -287,13 +337,15 @@
                 )
                 break
         loss = torch.stack(losses).mean()
-        self.test_losses.append(all_losses)
+        self.test_losses.append(losses)
+        self.test_relative_errors.append(all_losses)
+        self.test_relative_updates.append(relative_updates)
         self._log_loss(loss, batch, "test")
         return loss

     def on_test_end(self):
         if len(self.test_losses) > 0:
-            _plot_losses(self.test_losses, batch_idx=0)
+            _plot_losses(self.test_relative_errors, self.test_losses, self.test_relative_updates, batch_idx=0)

     def configure_optimizers(self):
         optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
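Note on the revised convergence test: `_check_convergence` now thresholds the norm of the update, ||y_new - y_old|| / ||y_old||, instead of the difference of the two norms, and returns the value alongside the boolean so `test_step` can record it. A minimal standalone sketch of the fixed-point rollout this gates is below; `step_fn`, `rollout_to_steady_state`, and the defaults are illustrative assumptions, not part of the module's API.

import torch


def relative_update(y_new: torch.Tensor, y_old: torch.Tensor) -> float:
    # ||y_new - y_old||_2 / ||y_old||_2, the quantity _check_convergence now thresholds.
    return (torch.norm(y_new - y_old, p=2) / torch.norm(y_old, p=2)).item()


def rollout_to_steady_state(step_fn, x0: torch.Tensor, tol: float = 1e-4, max_iters: int = 100):
    # Repeatedly apply the learned update and stop once the relative update
    # between consecutive iterates drops below tol (mirrors the loop in test_step).
    x = x0
    updates = []
    for _ in range(max_iters):
        x_next = step_fn(x)
        updates.append(relative_update(x_next, x))
        x = x_next
        if updates[-1] < tol:
            break
    return x, updates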
diff --git a/ThermalSolver/graph_datamodule.py b/ThermalSolver/graph_datamodule.py
index 564dee6..53b40a7 100644
--- a/ThermalSolver/graph_datamodule.py
+++ b/ThermalSolver/graph_datamodule.py
@@ -82,7 +82,7 @@ class GraphDataModule(LightningDataModule):
             conductivity = torch.tensor(
                 snapshot["conductivity"], dtype=torch.float32
             )
-            temperature = torch.tensor(snapshot["temperature"], dtype=torch.float32)
+            temperature = torch.tensor(snapshot["temperature"], dtype=torch.float32)[:50]

             pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2]

diff --git a/ThermalSolver/graph_datamodule_unsteady.py b/ThermalSolver/graph_datamodule_unsteady.py
index 9ec9df2..e3c69cc 100644
--- a/ThermalSolver/graph_datamodule_unsteady.py
+++ b/ThermalSolver/graph_datamodule_unsteady.py
@@ -23,6 +23,8 @@ class GraphDataModule(LightningDataModule):
         build_radial_graph: bool = False,
         radius: float = None,
         unrolling_steps: int = 1,
+        aggregate_timesteps: int = 1,
+        min_normalized_diff: float = 1e-3,
     ):
         super().__init__()
         self.hf_repo = hf_repo
@@ -35,6 +37,9 @@ class GraphDataModule(LightningDataModule):
             None,
         )
         self.unrolling_steps = unrolling_steps
+        self.aggregate_timesteps = aggregate_timesteps
+        self.min_normalized_diff = min_normalized_diff
+        self.geometry_dict = {}

         self.train_size = train_size
         self.val_size = val_size
@@ -109,10 +114,14 @@ class GraphDataModule(LightningDataModule):
                     dim=0,
                 )
             )
-            # print(temperatures.shape)
-
+            if not test:
+                for t in range(1, temperatures.size(0)):
+                    diff = temperatures[t, :] - temperatures[t - 1, :]
+                    norm_diff = torch.norm(diff, p=2) / torch.norm(temperatures[t - 1], p=2)
+                    if norm_diff < self.min_normalized_diff:
+                        temperatures = temperatures[: t + 1, :]
+                        break
             pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2]
-
             if self.build_radial_graph:
                 raise NotImplementedError(
                     "Radial graph building not implemented yet."
                 )
@@ -224,7 +233,7 @@ class GraphDataModule(LightningDataModule):
                 batch_size=self.batch_size,
                 shuffle=True,
                 num_workers=8,
-                pin_memory=True,
+                pin_memory=False,
             )

     def val_dataloader(self):
@@ -237,7 +246,7 @@ class GraphDataModule(LightningDataModule):
             batch_size=128,
             shuffle=False,
             num_workers=8,
-            pin_memory=True,
+            pin_memory=False,
         )

     def test_dataloader(self):
@@ -247,5 +256,5 @@ class GraphDataModule(LightningDataModule):
             batch_size=1,
             shuffle=False,
             num_workers=8,
-            pin_memory=True,
+            pin_memory=False,
         )
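The unsteady datamodule now truncates each training trajectory once consecutive snapshots stop changing, controlled by the new `min_normalized_diff` argument. A self-contained sketch of that truncation rule follows; the function name and the [T, N] tensor layout are assumptions made for illustration.

import torch


def truncate_at_steady_state(temperatures: torch.Tensor, min_normalized_diff: float = 1e-3) -> torch.Tensor:
    # temperatures: [T, N] trajectory of nodal temperatures.
    # Keep timesteps up to (and including) the first one whose normalized change
    # ||T_t - T_{t-1}|| / ||T_{t-1}|| falls below the threshold.
    for t in range(1, temperatures.size(0)):
        diff = temperatures[t] - temperatures[t - 1]
        norm_diff = torch.norm(diff, p=2) / torch.norm(temperatures[t - 1], p=2)
        if norm_diff < min_normalized_diff:
            return temperatures[: t + 1]
    return temperatures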
_plot_losses(self.test_losses, batch_idx=0) + _plot_losses(self.test_relative_errors, self.test_losses, self.test_relative_updates, batch_idx=0) def configure_optimizers(self): optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3) diff --git a/ThermalSolver/graph_datamodule.py b/ThermalSolver/graph_datamodule.py index 564dee6..53b40a7 100644 --- a/ThermalSolver/graph_datamodule.py +++ b/ThermalSolver/graph_datamodule.py @@ -82,7 +82,7 @@ class GraphDataModule(LightningDataModule): conductivity = torch.tensor( snapshot["conductivity"], dtype=torch.float32 ) - temperature = torch.tensor(snapshot["temperature"], dtype=torch.float32) + temperature = torch.tensor(snapshot["temperature"], dtype=torch.float32)[:50] pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2] diff --git a/ThermalSolver/graph_datamodule_unsteady.py b/ThermalSolver/graph_datamodule_unsteady.py index 9ec9df2..e3c69cc 100644 --- a/ThermalSolver/graph_datamodule_unsteady.py +++ b/ThermalSolver/graph_datamodule_unsteady.py @@ -23,6 +23,8 @@ class GraphDataModule(LightningDataModule): build_radial_graph: bool = False, radius: float = None, unrolling_steps: int = 1, + aggregate_timesteps: int = 1, + min_normalized_diff: float = 1e-3, ): super().__init__() self.hf_repo = hf_repo @@ -35,6 +37,9 @@ class GraphDataModule(LightningDataModule): None, ) self.unrolling_steps = unrolling_steps + self.aggregate_timesteps = aggregate_timesteps + self.min_normalized_diff = min_normalized_diff + self.geometry_dict = {} self.train_size = train_size self.val_size = val_size @@ -109,10 +114,14 @@ class GraphDataModule(LightningDataModule): dim=0, ) ) - # print(temperatures.shape) - + if not test: + for t in range(1, temperatures.size(0)): + diff = temperatures[t, :] - temperatures[t - 1, :] + norm_diff = torch.norm(diff, p=2) / torch.norm(temperatures[t - 1], p=2) + if norm_diff < self.min_normalized_diff: + temperatures = temperatures[: t + 1, :] + break pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2] - if self.build_radial_graph: raise NotImplementedError( "Radial graph building not implemented yet." @@ -224,7 +233,7 @@ class GraphDataModule(LightningDataModule): batch_size=self.batch_size, shuffle=True, num_workers=8, - pin_memory=True, + pin_memory=False, ) def val_dataloader(self): @@ -237,7 +246,7 @@ class GraphDataModule(LightningDataModule): batch_size=128, shuffle=False, num_workers=8, - pin_memory=True, + pin_memory=False, ) def test_dataloader(self): @@ -247,5 +256,5 @@ class GraphDataModule(LightningDataModule): batch_size=1, shuffle=False, num_workers=8, - pin_memory=True, + pin_memory=False, ) diff --git a/ThermalSolver/model/diffusion_net.py b/ThermalSolver/model/diffusion_net.py index 644b7c5..9c9ae09 100644 --- a/ThermalSolver/model/diffusion_net.py +++ b/ThermalSolver/model/diffusion_net.py @@ -123,3 +123,4 @@ class DiffusionNet(nn.Module): # 6. Final Update (Explicit Euler Step) # T_new = T_old + Correction return delta_x + x_input * self.dt + # return delta_x