import torch from lightning import LightningModule from torch_geometric.data import Batch import importlib from matplotlib import pyplot as plt from matplotlib.tri import Triangulation from .model.finite_difference import FiniteDifferenceStep import os def import_class(class_path: str): module_path, class_name = class_path.rsplit(".", 1) # split last dot module = importlib.import_module(module_path) # import the module cls = getattr(module, class_name) # get the class return cls def _plot_mesh(pos, y, y_pred, y_true ,batch, i, batch_idx): idx = batch == 0 y = y[idx].detach().cpu() y_pred = y_pred[idx].detach().cpu() pos = pos[idx].detach().cpu() y_true = y_true[idx].detach().cpu() # print(torch.max(y_true), torch.min(y_true)) folder = f"{batch_idx:02d}_images" if os.path.exists(folder) is False: os.makedirs(folder) pos = pos.detach().cpu() tria = Triangulation(pos[:, 0], pos[:, 1]) plt.figure(figsize=(18, 5)) plt.subplot(1, 3, 1) plt.tricontourf(tria, y.squeeze().numpy(), levels=14) plt.colorbar() plt.title("Step t-1") plt.subplot(1, 3, 2) plt.tricontourf(tria, y_pred.squeeze().numpy(), levels=14) plt.colorbar() plt.title("Step t Predicted") plt.subplot(1, 3, 3) plt.tricontourf(tria, y_true.squeeze().numpy(), levels=14) plt.colorbar() plt.title("t True") plt.suptitle("GNO", fontsize=16) name = f"{folder}/graph_iter_{i:04d}.png" plt.savefig(name, dpi=72) plt.close() def _plot_losses(losses, batch_idx): folder = f"{batch_idx:02d}_images" plt.figure() plt.plot(losses) plt.yscale("log") plt.xlabel("Iteration") plt.ylabel("Loss") plt.title("Test Loss over Iterations") plt.grid(True) file_name = f"{folder}/test_loss.png" plt.savefig(file_name, dpi=300) plt.close() class GraphSolver(LightningModule): def __init__( self, model_class_path: str, model_init_args: dict = {}, loss: torch.nn.Module = None, start_unrolling_steps: int = 1, increase_every: int = 20, increase_rate: float = 2, max_unrolling_steps: int = 100, max_inference_iters: int = 1000, inner_steps: int = 16, ): super().__init__() self.model = import_class(model_class_path)(**model_init_args) # for param in self.model.parameters(): # print(f"Param: {param.shape}, Grad: {param.grad}") # print(f"Param: {param[0]}") self.fd_net = FiniteDifferenceStep() self.loss = loss if loss is not None else torch.nn.MSELoss() self.start_unrolling = start_unrolling_steps self.current_unrolling_steps = self.start_unrolling self.increase_every = increase_every self.increase_rate = increase_rate self.max_unrolling_steps = max_unrolling_steps self.max_inference_iters = max_inference_iters self.threshold = 1e-4 self.inner_steps = inner_steps def _compute_deg(self, edge_index, edge_attr, num_nodes): deg = torch.zeros(num_nodes, device=edge_index.device) deg = deg.scatter_add(0, edge_index[1], edge_attr) return deg + 1e-7 def _compute_loss(self, x, y): return self.loss(x, y) def _log_loss(self, loss, batch, stage: str): self.log( f"{stage}/loss", loss, on_step=True, on_epoch=True, prog_bar=True, batch_size=int(batch.num_graphs), ) return loss @staticmethod def _compute_c_ij(c, edge_index): """ TODO: add docstring. """ return (0.5 * (c[edge_index[0]] + c[edge_index[1]])).squeeze() def _compute_model_steps( self, x, edge_index, edge_attr, boundary_mask, boundary_values ): out = x + self.model(x, edge_index, edge_attr) # out[boundary_mask] = boundary_values.unsqueeze(-1) plt.figure() return out def _check_convergence(self, out, x): residual_norm = torch.norm(out - x) if residual_norm < self.threshold * torch.norm(x): return True return False def _preprocess_batch(self, batch: Batch): x, y, c, edge_index, edge_attr = ( batch.x, batch.y, batch.c, batch.edge_index, batch.edge_attr, ) # edge_attr = 1 / edge_attr c_ij = self._compute_c_ij(c, edge_index) edge_attr = edge_attr * (c_ij) # / 100) return x, y, edge_index, edge_attr def training_step(self, batch: Batch): x, y, edge_index, edge_attr = self._preprocess_batch(batch) # deg = self._compute_deg(edge_index, edge_attr, x.size(0)) losses = [] # print(x.shape, y.shape) # # print(torch.max(edge_index), torch.min(edge_index)) # plt.figure() # plt.subplot(2,3,1) # plt.scatter(batch.pos[:,0].cpu(), batch.pos[:,1].cpu(), c=x.squeeze().cpu()) # plt.subplot(2,3,2) # plt.scatter(batch.pos[:,0].cpu(), batch.pos[:,1].cpu(), c=y[:,0,:].squeeze().cpu()) # plt.subplot(2,3,3) # plt.scatter(batch.pos[:,0].cpu(), batch.pos[:,1].cpu(), c=y[:,1,:].squeeze().cpu()) # plt.subplot(2,3,4) # plt.scatter(batch.pos[:,0].cpu(), batch.pos[:,1].cpu(), c=y[:,2,:].squeeze().cpu()) # plt.subplot(2,3,5) # plt.scatter(batch.pos[:,0].cpu(), batch.pos[:,1].cpu(), c=y[:,3,:].squeeze().cpu()) # plt.subplot(2,3,6) # plt.scatter(batch.pos[:,0].cpu(), batch.pos[:,1].cpu(), c=y[:,4,:].squeeze().cpu()) # plt.suptitle("Training Batch Visualization", fontsize=16) # plt.savefig("training_batch_visualization.png", dpi=300) # plt.close() # y = z pos = batch.pos boundary_mask = batch.boundary_mask boundary_values = batch.boundary_values # plt.scatter(pos[boundary_mask,0].cpu(), pos[boundary_mask,1].cpu(), c=boundary_values.cpu(), s=1) # plt.savefig("boundary_nodes.png", dpi=300) # y = z print(y.shape) for i in range(self.current_unrolling_steps * self.inner_steps): out = self._compute_model_steps( # torch.cat([x,pos], dim=-1), x, edge_index, edge_attr, # deg, batch.boundary_mask, batch.boundary_values, ) x = out # print(out.shape, y[:, i, :].shape) losses.append(self.loss(out.flatten(), y[:, i, :].flatten())) print(losses) loss = torch.stack(losses).mean() # for param in self.model.parameters(): # print(f"Param: {param.shape}, Grad: {param.grad}") # print(f"Param: {param[0]}") self._log_loss(loss, batch, "train") return loss # def on_train_epoch_start(self): # print(f"Current unrolling steps: {self.current_unrolling_steps}, dataset unrolling steps: {self.trainer.datamodule.train_dataset.unrolling_steps}") # return super().on_train_epoch_start() def on_train_epoch_end(self): if ( (self.current_epoch + 1) % self.increase_every == 0 and self.current_epoch > 0 ): dm = self.trainer.datamodule self.current_unrolling_steps = min( int(self.current_unrolling_steps * self.increase_rate), self.max_unrolling_steps ) dm.unrolling_steps = self.current_unrolling_steps return super().on_train_epoch_end() def validation_step(self, batch: Batch, _): # x, y, edge_index, edge_attr = self._preprocess_batch(batch) # deg = self._compute_deg(edge_index, edge_attr, x.size(0)) # for i in range(self.max_inference_iters * self.inner_steps): # out = self._compute_model_steps( # x, # edge_index, # edge_attr, # deg, # batch.boundary_mask, # batch.boundary_values, # ) # converged = self._check_convergence(out, x) # x = out # if converged: # break # print(y.shape, out.shape) # loss = self.loss(out, y[:,-1,:]) # self._log_loss(loss, batch, "val") # self.log("val/iterations", i + 1, on_step=False, on_epoch=True, prog_bar=True, batch_size=int(batch.num_graphs),) # return loss x, y, edge_index, edge_attr = self._preprocess_batch(batch) # deg = self._compute_deg(edge_index, edge_attr, x.size(0)) losses = [] pos = batch.pos for i in range(self.current_unrolling_steps * self.inner_steps): out = self._compute_model_steps( # torch.cat([x,pos], dim=-1), x, edge_index, edge_attr, # deg, batch.boundary_mask, batch.boundary_values, ) _plot_mesh(batch.pos, x, out, y[:, i, :], batch.batch, i, self.current_epoch) x = out losses.append(self.loss(out, y[:, i, :])) loss = torch.stack(losses).mean() self._log_loss(loss, batch, "val") return loss def test_step(self, batch: Batch, batch_idx): x, y, edge_index, edge_attr = self._preprocess_batch(batch) deg = self._compute_deg(edge_index, edge_attr, x.size(0)) losses = [] for i in range(self.max_iters): out = self._compute_model_steps( x, edge_index, edge_attr.unsqueeze(-1), deg, batch.boundary_mask, batch.boundary_values, ) converged = self._check_convergence(out, x) # _plot_mesh(batch.pos, y, out, batch.batch, i, batch_idx) losses.append(self.loss(out, y).item()) if converged: break x = out loss = self.loss(out, y) # _plot_losses(losses, batch_idx) self._log_loss(loss, batch, "test") self.log( "test/iterations", i + 1, on_step=False, on_epoch=True, prog_bar=True, batch_size=int(batch.num_graphs), ) def configure_optimizers(self): optimizer = torch.optim.AdamW(self.parameters(), lr=1e-2) return optimizer def _impose_bc(self, x: torch.Tensor, data: Batch): x[data.boundary_mask] = data.boundary_values return x