import importlib
import os
from typing import Optional

import torch
from lightning import LightningModule
from matplotlib import pyplot as plt
from matplotlib.tri import Triangulation
from torch_geometric.data import Batch

from .model.finite_difference import FiniteDifferenceStep


def import_class(class_path: str):
    """Resolve a dotted class path (e.g. "package.module.ClassName") to the class object."""
    module_path, class_name = class_path.rsplit(".", 1)  # split on the last dot
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


def _plot_mesh(pos_, y_, y_pred_, y_true_, batch, cells, i, batch_idx):
    # Plot only the first graph of the batch.
    for j in [0]:
        idx = (batch == j).nonzero(as_tuple=True)[0]
        y_pred = y_pred_[idx].detach().cpu()
        pos = pos_[idx].detach().cpu()
        y_true = torch.clamp(y_true_[idx].detach().cpu(), min=0)
        cells = cells.detach().cpu()

        folder = f"{batch_idx:02d}_images"
        os.makedirs(folder, exist_ok=True)

        # Split each quadrilateral cell into two triangles for matplotlib.
        triangles = torch.vstack([cells[:, [0, 1, 2]], cells[:, [0, 2, 3]]])
        tria = Triangulation(pos[:, 0], pos[:, 1], triangles=triangles)

        plt.figure(figsize=(24, 6))

        plt.subplot(1, 4, 1)
        plt.tricontourf(tria, y_pred.squeeze().numpy(), levels=100)
        plt.colorbar()
        plt.title(f"Prediction at timestep {i:03d}")

        plt.subplot(1, 4, 2)
        plt.tricontourf(tria, y_true.squeeze().numpy(), levels=100)
        plt.colorbar()
        plt.title("Ground Truth Steady State")

        plt.subplot(1, 4, 3)
        per_element_relative_error = torch.abs(y_pred - y_true) / (y_true + 1e-6)
        per_element_relative_error = torch.clamp(
            per_element_relative_error, min=0.0, max=1.0
        )
        plt.tricontourf(
            tria,
            per_element_relative_error.squeeze().numpy(),
            levels=100,
            vmin=0,
            vmax=1.0,
        )
        plt.colorbar()
        plt.title("Relative Error")

        plt.subplot(1, 4, 4)
        absolute_error = torch.abs(y_pred - y_true)
        plt.tricontourf(tria, absolute_error.squeeze().numpy(), levels=100)
        plt.colorbar()
        plt.title("Absolute Error")

        plt.suptitle("GNO", fontsize=16)
        plt.savefig(f"{folder}/{j:04d}_graph_iter_{i:04d}.png", dpi=72)
        plt.close()


def _plot_losses(relative_errors, test_losses, relative_update, batch_idx):
    plt.figure(figsize=(18, 6))

    plt.subplot(1, 3, 1)
    # Plot at most the first four test batches per panel.
    for i, losses in enumerate(test_losses):
        plt.plot(losses)
        if i == 3:
            break
    plt.yscale("log")
    plt.xlabel("Iteration")
    plt.ylabel("Test Loss")
    plt.title("Test Loss over Iterations")
    plt.grid(True)

    plt.subplot(1, 3, 2)
    for i, errors in enumerate(relative_errors):
        plt.plot(errors)
        if i == 3:
            break
    plt.yscale("log")
    plt.xlabel("Iteration")
    plt.ylabel("Relative Error")
    plt.title("Relative Error over Iterations")
    plt.grid(True)

    plt.subplot(1, 3, 3)
    for i, updates in enumerate(relative_update):
        plt.plot(updates)
        if i == 3:
            break
    plt.yscale("log")
    plt.xlabel("Iteration")
    plt.ylabel("Relative Update")
    plt.title("Relative Update over Iterations")
    plt.grid(True)

    plt.savefig("test_errors.png", dpi=300)
    plt.close()

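# Illustrative use of import_class, defined above (a sketch only;
# "torch.nn.Linear" is an arbitrary example of a dotted class path, not
# something this module requires):
#
#   linear_cls = import_class("torch.nn.Linear")
#   layer = linear_cls(4, 8)  # same as torch.nn.Linear(4, 8)
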
class GraphSolver(LightningModule):
    def __init__(
        self,
        model_class_path: str,
        model_init_args: Optional[dict] = None,
        loss: Optional[torch.nn.Module] = None,
        unrolling_steps: int = 1,
    ):
        super().__init__()
        # `model_init_args or {}` avoids a mutable default argument.
        self.model = import_class(model_class_path)(**(model_init_args or {}))
        self.loss = loss if loss is not None else torch.nn.MSELoss()
        self.unrolling_steps = unrolling_steps
        self.test_losses = []
        self.test_relative_errors = []
        self.test_relative_updates = []

    def _compute_loss(self, x, y):
        return self.loss(x, y)

    def _log_loss(self, loss, batch, stage: str):
        self.log(
            f"{stage}/loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            batch_size=int(batch.num_graphs),
        )
        return loss

    @staticmethod
    def _compute_c_ij(c, edge_index):
        """Per-edge conductivity as the arithmetic mean of the conductivities
        of the two endpoint nodes: c_ij = 0.5 * (c_i + c_j)."""
        return (0.5 * (c[edge_index[0]] + c[edge_index[1]])).squeeze()

    def _compute_model_steps(
        self,
        x,
        edge_index,
        edge_attr,
        boundary_mask,
        boundary_values,
        conductivity,
    ):
        out = self.model(x, edge_index, edge_attr, conductivity)
        # Re-impose the Dirichlet boundary values after every model step.
        out[boundary_mask] = boundary_values.unsqueeze(-1)
        return out

    def _preprocess_batch(self, batch: Batch):
        x, y, c, edge_index, edge_attr = (
            batch.x,
            batch.y,
            batch.c,
            batch.edge_index,
            batch.edge_attr,
        )
        # Invert the raw edge attributes and scale them by the per-edge conductivity.
        edge_attr = 1 / edge_attr
        conductivity = self._compute_c_ij(c, edge_index)
        edge_attr = edge_attr * conductivity
        return x, y, edge_index, edge_attr, conductivity

    def training_step(self, batch: Batch):
        x, y, edge_index, edge_attr, conductivity = self._preprocess_batch(batch)

        # Unroll the model for several steps and average the per-step losses.
        losses = []
        for i in range(self.unrolling_steps):
            out = self._compute_model_steps(
                x,
                edge_index,
                edge_attr,
                batch.boundary_mask,
                batch.boundary_values,
                conductivity,
            )
            x = out
            losses.append(self.loss(out.flatten(), y[:, i, :].flatten()))
        loss = torch.stack(losses).mean()
        self._log_loss(loss, batch, "train")

        # Log the learned per-layer alpha values and the global time step dt.
        for i, layer in enumerate(self.model.layers):
            self.log(
                f"{i:03d}_alpha",
                layer.alpha,
                prog_bar=True,
                on_epoch=True,
                on_step=False,
                batch_size=int(batch.num_graphs),
            )
        self.log(
            "dt",
            self.model.dt,
            prog_bar=True,
            on_epoch=True,
            on_step=False,
            batch_size=int(batch.num_graphs),
        )
        return loss

    def validation_step(self, batch: Batch, batch_idx):
        x, y, edge_index, edge_attr, conductivity = self._preprocess_batch(batch)

        losses = []
        for i in range(self.unrolling_steps):
            out = self._compute_model_steps(
                x,
                edge_index,
                edge_attr,
                batch.boundary_mask,
                batch.boundary_values,
                conductivity,
            )
            x = out
            losses.append(self.loss(out, y[:, i, :]))
        loss = torch.stack(losses).mean()
        self._log_loss(loss, batch, "val")
        return loss

    def _check_convergence(self, y_new, y_old, tol=1e-4):
        l2_norm = torch.norm(y_new - y_old, p=2)
        y_old_norm = torch.norm(y_old, p=2)
        rel_error = l2_norm / y_old_norm
        return rel_error.item() < tol, rel_error.item()

    def test_step(self, batch: Batch, batch_idx):
        x, y, edge_index, edge_attr, conductivity = self._preprocess_batch(batch)

        losses = []
        all_losses = []
        norms = []
        relative_updates = []
        # Keep only the final (steady-state) target.
        y = y[:, -1, :].unsqueeze(1)
        # Plot the initial state before any model step.
        _plot_mesh(
            batch.pos,
            x,
            x,
            y[:, -1, :],
            batch.batch,
            batch.cells,
            0,
            batch_idx,
        )
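        # Fixed-point iteration toward the steady state: apply the learned step
        # repeatedly and stop once the relative update
        #   ||x_{k+1} - x_k||_2 / ||x_k||_2 < 1e-4
        # (see _check_convergence) or after at most 200 iterations.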
        for i in range(200):
            out = self._compute_model_steps(
                x,
                edge_index,
                edge_attr,
                batch.boundary_mask,
                batch.boundary_values,
                conductivity,
            )
            norms.append(torch.norm(out - x, p=2).item())
            converged, relative_update = self._check_convergence(out, x)
            relative_updates.append(relative_update)
            if batch_idx <= 4:
                print(f"Plotting iteration {i}, norm diff: {norms[-1]}")
                _plot_mesh(
                    batch.pos,
                    x,
                    out,
                    y[:, -1, :],
                    batch.batch,
                    batch.cells,
                    i + 1,
                    batch_idx,
                )
            x = out
            loss = self.loss(out, y[:, -1, :])
            relative_error = torch.abs(out - y[:, -1, :]) / (
                torch.abs(y[:, -1, :]) + 1e-6
            )
            all_losses.append(relative_error.mean().item())
            losses.append(loss)
            if converged:
                print(f"Test step converged at iteration {i} for batch {batch_idx}")
                break

        loss = torch.stack(losses).mean()
        # Store plain floats so that plotting in on_test_end works regardless of
        # the device the test ran on.
        self.test_losses.append([l.item() for l in losses])
        self.test_relative_errors.append(all_losses)
        self.test_relative_updates.append(relative_updates)
        self._log_loss(loss, batch, "test")
        return loss

    def on_test_end(self):
        if len(self.test_losses) > 0:
            _plot_losses(
                self.test_relative_errors,
                self.test_losses,
                self.test_relative_updates,
                batch_idx=0,
            )

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        return optimizer
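
# Minimal usage sketch, under assumptions not fixed by this module:
# "my_project.models.GNO" is a hypothetical dotted model class path, and
# `datamodule` is a LightningDataModule yielding torch_geometric Batch objects
# with x, y, c, pos, cells, edge_index, edge_attr, boundary_mask and
# boundary_values attributes.
#
#   from lightning import Trainer
#
#   solver = GraphSolver(
#       model_class_path="my_project.models.GNO",
#       model_init_args={"hidden_dim": 64},  # hypothetical init args
#       unrolling_steps=4,
#   )
#   trainer = Trainer(max_epochs=100)
#   trainer.fit(solver, datamodule=datamodule)
#   trainer.test(solver, datamodule=datamodule)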