import importlib
import os
from typing import Optional

import torch
from lightning import LightningModule
from matplotlib import pyplot as plt
from matplotlib.tri import Triangulation
from torch_geometric.data import Batch

# Not referenced in this module; kept because other code may rely on the
# import being performed here.
from .model.finite_difference import FiniteDifferenceStep  # noqa: F401


def import_class(class_path: str):
    """Resolve a dotted ``"package.module.ClassName"`` path to the object.

    Args:
        class_path: Dotted path whose last component is the attribute to
            fetch from the module named by everything before it.

    Returns:
        The attribute ``ClassName`` of the imported module.

    Raises:
        ValueError: If ``class_path`` contains no dot.
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute ``ClassName``.
    """
    if "." not in class_path:
        raise ValueError(
            f"Expected a dotted path 'module.ClassName', got {class_path!r}"
        )
    # Split on the last dot only: everything before it is the module path.
    module_path, class_name = class_path.rsplit(".", 1)
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


def _plot_mesh(pos_, y_, y_pred_, y_true_, batch, i, batch_idx):
    """Save a 4-panel contour figure for a few graphs of a batch.

    For each graph id in ``[0, 10, 20, 30]`` the nodes with
    ``batch == id`` are selected and four filled contour plots are drawn on
    a triangulation of their first two position columns: the previous state,
    the prediction, the (clamped to ``>= 0``) target and their difference.
    The figure is written to ``"{id:02d}_images/{id:04d}_graph_iter_{i:04d}.png"``.

    Graph ids that are absent from the batch, or that have fewer than three
    nodes, are skipped because no triangulation can be built for them.

    Args:
        pos_: Node positions; columns 0 and 1 are used as x/y coordinates.
        y_: Node values before the step.
        y_pred_: Node values predicted by the step.
        y_true_: Target node values for the step.
        batch: Per-node graph-id vector.
        i: Unrolling step index, used in the file name.
        batch_idx: Accepted for interface compatibility; not used. Note the
            file name does not depend on it, so later calls with the same
            ``i`` overwrite earlier images.
    """
    for j in [0, 10, 20, 30]:
        idx = (batch == j).nonzero(as_tuple=True)[0]
        if idx.numel() < 3:
            # Graph not present in this batch (or too small to triangulate).
            continue

        y = y_[idx].detach().cpu()
        y_pred = y_pred_[idx].detach().cpu()
        pos = pos_[idx].detach().cpu()
        y_true = torch.clamp(y_true_[idx].detach().cpu(), min=0)

        folder = f"{j:02d}_images"
        os.makedirs(folder, exist_ok=True)

        tria = Triangulation(pos[:, 0].numpy(), pos[:, 1].numpy())
        panels = (
            ("Step t-1", y),
            ("Step t Predicted", y_pred),
            ("t True", y_true),
            ("Error", y_true - y_pred),
        )

        fig = plt.figure(figsize=(24, 5))
        try:
            for col, (title, values) in enumerate(panels, start=1):
                plt.subplot(1, 4, col)
                plt.tricontourf(tria, values.squeeze().numpy(), levels=100)
                plt.colorbar()
                plt.title(title)
            plt.suptitle("GNO", fontsize=16)
            name = f"{folder}/{j:04d}_graph_iter_{i:04d}.png"
            plt.savefig(name, dpi=72)
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)


def _plot_losses(losses, batch_idx):
    """Plot a sequence of values on a log-scaled y axis and save it.

    The image is written to ``"{batch_idx:02d}_images/test_loss.png"``; the
    directory is created if it does not exist yet.

    Args:
        losses: Sequence of scalar values, one per iteration.
        batch_idx: Integer used to build the output folder name.
    """
    folder = f"{batch_idx:02d}_images"
    # The folder is not guaranteed to exist (it is otherwise only created by
    # _plot_mesh for specific graph ids), so create it before saving.
    os.makedirs(folder, exist_ok=True)

    fig = plt.figure()
    try:
        plt.plot(losses)
        plt.yscale("log")
        plt.xlabel("Iteration")
        plt.ylabel("Loss")
        plt.title("Test Loss over Iterations")
        plt.grid(True)
        file_name = f"{folder}/test_loss.png"
        plt.savefig(file_name, dpi=300)
    finally:
        plt.close(fig)


class GraphSolver(LightningModule):
    """Lightning module that trains a graph model by unrolling it in time.

    The wrapped model is applied ``unrolling_steps`` times, feeding each
    output back in as the next input and re-imposing boundary values after
    every application. The loss is the mean of the per-step losses against
    ``y[:, i, :]``.
    """

    def __init__(
        self,
        model_class_path: str,
        model_init_args: Optional[dict] = None,
        loss: Optional[torch.nn.Module] = None,
        unrolling_steps: int = 1,
    ):
        """Build the wrapped model and store the training configuration.

        Args:
            model_class_path: Dotted path of the model class to instantiate.
            model_init_args: Keyword arguments for the model constructor;
                ``None`` means no arguments.
            loss: Loss module; defaults to ``torch.nn.MSELoss()``.
            unrolling_steps: Number of consecutive model applications per
                batch.
        """
        super().__init__()
        # Avoid a shared mutable default: build a fresh dict when omitted.
        model_kwargs = {} if model_init_args is None else model_init_args
        self.model = import_class(model_class_path)(**model_kwargs)
        self.loss = loss if loss is not None else torch.nn.MSELoss()
        self.unrolling_steps = unrolling_steps

    def _compute_loss(self, x, y):
        """Return ``self.loss(x, y)``."""
        return self.loss(x, y)

    def _log_loss(self, loss, batch, stage: str):
        """Log ``loss`` as ``"{stage}/loss"`` aggregated per epoch.

        Args:
            loss: Value to log.
            batch: Batch whose ``num_graphs`` is used as the batch size.
            stage: Prefix of the metric name (e.g. ``"train"``).

        Returns:
            The ``loss`` argument, unchanged.
        """
        self.log(
            f"{stage}/loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            batch_size=int(batch.num_graphs),
        )
        return loss

    @staticmethod
    def _compute_c_ij(c, edge_index):
        """Return the per-edge arithmetic mean of ``c`` at both endpoints.

        Args:
            c: Per-node values, indexed by node id along the first axis.
            edge_index: Pair of node-index tensors; ``edge_index[0]`` and
                ``edge_index[1]`` hold the two endpoints of every edge.

        Returns:
            ``0.5 * (c[src] + c[dst])`` with singleton dimensions squeezed.
        """
        return (0.5 * (c[edge_index[0]] + c[edge_index[1]])).squeeze()

    def _compute_model_steps(
        self,
        x,
        edge_index,
        edge_attr,
        boundary_mask,
        boundary_values,
        conductivity,
    ):
        """Apply the model once and overwrite the boundary nodes.

        Args:
            x: Current node values.
            edge_index: Graph connectivity.
            edge_attr: Edge attributes passed to the model.
            boundary_mask: Index/mask selecting the boundary nodes of the
                model output.
            boundary_values: Values written to the boundary nodes (a trailing
                dimension is added before assignment).
            conductivity: Per-edge conductivity passed to the model.

        Returns:
            The model output with boundary entries replaced in place.
        """
        out = self.model(x, edge_index, edge_attr, conductivity)
        out[boundary_mask] = boundary_values.unsqueeze(-1)
        return out

    def _preprocess_batch(self, batch: Batch):
        """Extract tensors from ``batch`` and build weighted edge attributes.

        The edge attribute is inverted and multiplied by the per-edge
        conductivity computed from ``batch.c``.

        Returns:
            Tuple ``(x, y, edge_index, edge_attr, conductivity)``.
        """
        x, y, c = batch.x, batch.y, batch.c
        edge_index = batch.edge_index
        conductivity = self._compute_c_ij(c, edge_index)
        edge_attr = (1 / batch.edge_attr) * conductivity
        return x, y, edge_index, edge_attr, conductivity

    def training_step(self, batch: Batch):
        """Unroll the model, log metrics and return the mean step loss.

        Besides ``"train/loss"``, logs ``alpha_{i}`` for every entry of
        ``self.model.layers`` and ``dt`` from ``self.model.dt``.
        """
        x, y, edge_index, edge_attr, conductivity = self._preprocess_batch(
            batch
        )
        losses = []
        for i in range(self.unrolling_steps):
            out = self._compute_model_steps(
                x,
                edge_index,
                edge_attr,
                batch.boundary_mask,
                batch.boundary_values,
                conductivity,
            )
            # Feed the prediction back in as the next step's input.
            x = out
            losses.append(self.loss(out.flatten(), y[:, i, :].flatten()))
        loss = torch.stack(losses).mean()
        self._log_loss(loss, batch, "train")

        num_graphs = int(batch.num_graphs)
        for i, layer in enumerate(self.model.layers):
            self.log(
                f"alpha_{i}",
                layer.alpha,
                prog_bar=True,
                on_epoch=True,
                on_step=False,
                batch_size=num_graphs,
            )
        self.log(
            "dt",
            self.model.dt,
            prog_bar=True,
            on_epoch=True,
            on_step=False,
            batch_size=num_graphs,
        )
        return loss

    def validation_step(self, batch: Batch, batch_idx):
        """Unroll the model on a validation batch and log ``"val/loss"``.

        For the first batch of every 10th epoch (excluding epoch 0) a mesh
        figure is saved for each unrolling step.
        """
        x, y, edge_index, edge_attr, conductivity = self._preprocess_batch(
            batch
        )
        # The condition does not change inside the loop, so evaluate it once.
        should_plot = (
            batch_idx == 0
            and self.current_epoch % 10 == 0
            and self.current_epoch > 0
        )
        losses = []
        for i in range(self.unrolling_steps):
            out = self._compute_model_steps(
                x,
                edge_index,
                edge_attr,
                batch.boundary_mask,
                batch.boundary_values,
                conductivity,
            )
            if should_plot:
                # ``x`` still holds the state before this step here.
                _plot_mesh(
                    batch.pos,
                    x,
                    out,
                    y[:, i, :],
                    batch.batch,
                    i,
                    self.current_epoch,
                )
            x = out
            losses.append(self.loss(out, y[:, i, :]))
        loss = torch.stack(losses).mean()
        self._log_loss(loss, batch, "val")
        return loss

    def _check_convergence(self, y_pred, y_true, tol=1e-3):
        """Return whether the relative L2 error is below ``tol``.

        The error is ``||y_pred - y_true||_2 / (||y_true||_2 + 1e-8)``.
        """
        l2_norm = torch.norm(y_pred - y_true, p=2)
        y_true_norm = torch.norm(y_true, p=2)
        rel_error = l2_norm / (y_true_norm + 1e-8)
        return rel_error.item() < tol

    def test_step(self, batch: Batch, batch_idx):
        """Unroll the model on a test batch and log ``"test/loss"``.

        Also records the L2 norm of the change between consecutive states at
        every step and saves that curve via ``_plot_losses`` into the folder
        named after the current epoch.
        """
        x, y, edge_index, edge_attr, conductivity = self._preprocess_batch(
            batch
        )
        losses = []
        norms = []
        for i in range(self.unrolling_steps):
            out = self._compute_model_steps(
                x,
                edge_index,
                edge_attr,
                batch.boundary_mask,
                batch.boundary_values,
                conductivity,
            )
            # Size of the update produced by this step.
            norms.append(torch.norm(out - x, p=2).item())
            x = out
            losses.append(self.loss(out, y[:, i, :]))
        loss = torch.stack(losses).mean()
        _plot_losses(norms, self.current_epoch)
        self._log_loss(loss, batch, "test")
        return loss

    def configure_optimizers(self):
        """Return an ``AdamW`` optimizer over all parameters (lr ``1e-3``)."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        return optimizer