import importlib
import os
from typing import Optional

import torch
from lightning import LightningModule
from matplotlib import pyplot as plt
from matplotlib.tri import Triangulation
from torch_geometric.data import Batch

# Not referenced in this module; kept because other code may rely on it
# being importable from here.
from .model.finite_difference import FiniteDifferenceStep  # noqa: F401


def import_class(class_path: str):
    """Import and return the attribute named by a fully-qualified dotted path.

    Example: ``"torch.nn.MSELoss"`` imports ``torch.nn`` and returns its
    ``MSELoss`` attribute.

    Raises:
        ValueError: if ``class_path`` contains no ``.`` (no module part).
        ModuleNotFoundError / AttributeError: propagated from the import or
            the attribute lookup.
    """
    # Split on the last dot: everything before it is the module path.
    module_path, sep, class_name = class_path.rpartition(".")
    if not sep:
        raise ValueError(
            "Expected a dotted path like 'package.module.ClassName', "
            f"got {class_path!r}"
        )
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


def _plot_mesh(pos_, y_, y_pred_, y_true_, batch, i, batch_idx):
    """Save 4-panel contour plots (input, prediction, target, error) for a few graphs.

    Graphs with ids 0, 10, 20 and 30 in ``batch`` are plotted. Each one is
    written to ``{j:02d}_images/{j:04d}_graph_iter_{i:04d}.png``. Graph ids
    that are absent from the batch, or that have fewer than 3 nodes (too few
    to triangulate), are skipped instead of crashing.

    Args:
        pos_: node coordinates. Columns 0 and 1 are used for the triangulation.
        y_: values fed to the model at this step (panel "Step t-1").
        y_pred_: model prediction for this step.
        y_true_: target for this step. Clamped at 0 before plotting, and the
            error panel is computed from the clamped target.
        batch: node-to-graph assignment vector (as passed by the caller).
        i: unrolling step index, used in the file name.
        batch_idx: not used here; the visible caller passes the current epoch.
    """
    # NOTE(review): the file name does not include the epoch, so each plotting
    # epoch overwrites the previous images -- confirm this is intended.
    for j in (0, 10, 20, 30):
        idx = (batch == j).nonzero(as_tuple=True)[0]
        if idx.numel() < 3:
            continue

        y = y_[idx].detach().cpu()
        y_pred = y_pred_[idx].detach().cpu()
        pos = pos_[idx].detach().cpu()
        y_true = torch.clamp(y_true_[idx].detach().cpu(), min=0)

        folder = f"{j:02d}_images"
        os.makedirs(folder, exist_ok=True)

        tria = Triangulation(pos[:, 0].numpy(), pos[:, 1].numpy())
        panels = (
            (y, "Step t-1"),
            (y_pred, "Step t Predicted"),
            (y_true, "t True"),
            (y_true - y_pred, "Error"),
        )

        fig = plt.figure(figsize=(24, 5))
        try:
            for k, (field, title) in enumerate(panels, start=1):
                plt.subplot(1, 4, k)
                plt.tricontourf(tria, field.squeeze().numpy(), levels=100)
                plt.colorbar()
                plt.title(title)
            plt.suptitle("GNO", fontsize=16)
            name = f"{folder}/{j:04d}_graph_iter_{i:04d}.png"
            plt.savefig(name, dpi=72)
        finally:
            # Always release the figure, even if plotting fails.
            plt.close(fig)


def _plot_losses(losses, batch_idx):
    """Save a log-scale plot of ``losses`` to ``{batch_idx:02d}_images/test_loss.png``.

    The output folder is created if it does not exist.
    """
    folder = f"{batch_idx:02d}_images"
    os.makedirs(folder, exist_ok=True)
    fig = plt.figure()
    try:
        plt.plot(losses)
        plt.yscale("log")
        plt.xlabel("Iteration")
        plt.ylabel("Loss")
        plt.title("Test Loss over Iterations")
        plt.grid(True)
        file_name = f"{folder}/test_loss.png"
        plt.savefig(file_name, dpi=300)
    finally:
        plt.close(fig)


class GraphSolver(LightningModule):
    """Lightning module that trains a graph model by autoregressive unrolling.

    The wrapped model is built from a dotted class path and keyword arguments.
    In training and validation, the model's output is fed back as the next
    input for ``unrolling_steps`` steps. Boundary values are re-imposed after
    every step, and the per-step losses are averaged.
    """

    def __init__(
        self,
        model_class_path: str,
        model_init_args: Optional[dict] = None,
        loss: Optional[torch.nn.Module] = None,
        unrolling_steps: int = 1,
        lr: float = 5e-3,
    ):
        """
        Args:
            model_class_path: dotted path of the model class, resolved with
                ``import_class``.
            model_init_args: keyword arguments for the model constructor.
                ``None`` means no arguments.
            loss: loss module. Defaults to ``torch.nn.MSELoss()``.
            unrolling_steps: number of autoregressive steps per batch. The
                target ``batch.y`` is indexed as ``y[:, i, :]`` for each step.
            lr: learning rate for the AdamW optimizer.
        """
        super().__init__()
        # None default avoids sharing one mutable dict across instances.
        if model_init_args is None:
            model_init_args = {}
        self.model = import_class(model_class_path)(**model_init_args)
        self.loss = loss if loss is not None else torch.nn.MSELoss()
        self.unrolling_steps = unrolling_steps
        self.lr = lr

    def _compute_loss(self, x, y):
        """Apply the configured loss to prediction ``x`` and target ``y``."""
        return self.loss(x, y)

    def _log_loss(self, loss, batch, stage: str):
        """Log ``loss`` as ``{stage}/loss`` (epoch-aggregated, shown on the progress bar).

        ``batch_size`` is set to the number of graphs so that epoch averaging
        weights each batch by its graph count. Returns ``loss`` unchanged.
        """
        self.log(
            f"{stage}/loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            batch_size=int(batch.num_graphs),
        )
        return loss

    @staticmethod
    def _compute_c_ij(c, edge_index):
        """Average a per-node coefficient onto edges.

        For each edge ``(edge_index[0][e], edge_index[1][e])`` this returns
        ``0.5 * (c[src] + c[dst])``, with all size-1 dimensions squeezed away.
        """
        return (0.5 * (c[edge_index[0]] + c[edge_index[1]])).squeeze()

    def _compute_model_steps(
        self, x, edge_index, edge_attr, boundary_mask, boundary_values
    ):
        """Run one model step and overwrite boundary nodes with prescribed values.

        ``out[boundary_mask]`` is set to ``boundary_values.unsqueeze(-1)``. This
        presumes one output channel per node -- TODO confirm.
        """
        out = self.model(x, edge_index, edge_attr)
        # Clone before the masked in-place write. If the model's last op saved
        # its output for backward, writing into it directly would make
        # autograd raise.
        out = out.clone()
        out[boundary_mask] = boundary_values.unsqueeze(-1)
        return out

    def _preprocess_batch(self, batch: Batch):
        """Extract tensors from ``batch`` and build the edge weights.

        The returned ``edge_attr`` is ``c_ij / batch.edge_attr``, where ``c_ij``
        is the edge average of ``batch.c``. Presumably ``batch.edge_attr``
        holds edge lengths (TODO confirm). A zero entry yields ``inf``.

        Returns:
            ``(x, y, edge_index, edge_attr)``.
        """
        x, y, c, edge_index, edge_attr = (
            batch.x,
            batch.y,
            batch.c,
            batch.edge_index,
            batch.edge_attr,
        )
        edge_attr = 1 / edge_attr
        c_ij = self._compute_c_ij(c, edge_index)
        edge_attr = edge_attr * c_ij
        return x, y, edge_index, edge_attr

    def _unrolled_loss(self, batch: Batch, plot: bool = False):
        """Unroll the model for ``unrolling_steps`` steps and return the mean step loss.

        Each step's output becomes the next step's input, so gradients flow
        through the whole rollout. Prediction and target are flattened before
        the loss so that train and validation losses are computed identically.
        When ``plot`` is true, every step is rendered with ``_plot_mesh``.
        """
        x, y, edge_index, edge_attr = self._preprocess_batch(batch)
        losses = []
        for i in range(self.unrolling_steps):
            out = self._compute_model_steps(
                x,
                edge_index,
                edge_attr,
                batch.boundary_mask,
                batch.boundary_values,
            )
            target = y[:, i, :]
            if plot:
                _plot_mesh(
                    batch.pos, x, out, target, batch.batch, i, self.current_epoch
                )
            x = out
            losses.append(self._compute_loss(out.flatten(), target.flatten()))
        return torch.stack(losses).mean()

    def training_step(self, batch: Batch):
        """Compute, log and return the unrolled training loss."""
        loss = self._unrolled_loss(batch)
        self._log_loss(loss, batch, "train")
        return loss

    def validation_step(self, batch: Batch, batch_idx):
        """Compute, log and return the unrolled validation loss.

        Mesh plots are written for the first validation batch on every 10th
        epoch after epoch 20.
        """
        plot = (
            batch_idx == 0
            and self.current_epoch % 10 == 0
            and self.current_epoch > 20
        )
        loss = self._unrolled_loss(batch, plot=plot)
        self._log_loss(loss, batch, "val")
        return loss

    def test_step(self, batch: Batch, batch_idx):
        """Testing is not implemented; this hook does nothing."""
        pass

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters with learning rate ``self.lr``."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return optimizer