import torch from lightning import LightningModule from torch_geometric.data import Batch import importlib from matplotlib import pyplot as plt from matplotlib.tri import Triangulation def import_class(class_path: str): module_path, class_name = class_path.rsplit(".", 1) # split last dot module = importlib.import_module(module_path) # import the module cls = getattr(module, class_name) # get the class return cls def _plot_mesh(pos, y, y_pred, batch, i): idx = batch == 0 y = y[idx].detach().cpu() y_pred = y_pred[idx].detach().cpu() pos = pos[idx].detach().cpu() pos = pos.detach().cpu() tria = Triangulation(pos[:, 0], pos[:, 1]) plt.figure(figsize=(18, 5)) plt.subplot(1, 3, 1) plt.tricontourf(tria, y.squeeze().numpy(), levels=14) plt.colorbar() plt.title("True temperature") plt.subplot(1, 3, 2) plt.tricontourf(tria, y_pred.squeeze().numpy(), levels=14) plt.colorbar() plt.title("Predicted temperature") plt.subplot(1, 3, 3) plt.tricontourf(tria, torch.abs(y_pred - y).squeeze().numpy(), levels=14) plt.colorbar() plt.title("Error") plt.suptitle("GNO", fontsize=16) name = f"images/graph_iter_{i:04d}.png" plt.savefig(name, dpi=72) plt.close() class GraphSolver(LightningModule): def __init__( self, model_class_path: str, model_init_args: dict = {}, loss: torch.nn.Module = None, curriculum_learning: bool = False, start_iters: int = 10, increase_every: int = 100, increase_rate: float = 1.1, max_iters: int = 1000, accumulation_iters: int = None, ): super().__init__() self.model = import_class(model_class_path)(**model_init_args) self.loss = loss if loss is not None else torch.nn.MSELoss() self.curriculum_learning = curriculum_learning self.start_iters = start_iters self.increase_every = increase_every self.increase_rate = increase_rate self.max_iters = max_iters self.current_iters = start_iters self.accumulation_iters = accumulation_iters self.automatic_optimization = False self.threshold = 1e-5 def _compute_deg(self, edge_index, edge_attr, num_nodes): deg = torch.zeros(num_nodes, device=edge_index.device) deg = deg.scatter_add(0, edge_index[1], edge_attr) return deg + 1e-7 def _compute_loss(self, x, y): return self.loss(x, y) def _log_loss(self, loss, batch, stage: str): self.log( f"{stage}/loss", loss, on_step=False, on_epoch=True, prog_bar=True, batch_size=int(batch.num_graphs), ) return loss @staticmethod def _compute_c_ij(c, edge_index): """ TODO: add docstring. """ return (0.5 * (c[edge_index[0]] + c[edge_index[1]])).squeeze() def _compute_model_steps( self, x, edge_index, edge_attr, deg, boundary_mask, boundary_values ): out = self.model(x, edge_index, edge_attr, deg) out[boundary_mask] = boundary_values.unsqueeze(-1) return out def _check_convergence(self, out, x): residual_norm = torch.norm(out - x) if residual_norm < self.threshold * torch.norm(x): return True return False def accumulate_gradients(self, losses, max_acc_iters): loss_ = torch.stack(losses, dim=0).mean() loss = 0.5 * loss_ / self.accumulation_iters self.manual_backward(loss / max_acc_iters, retain_graph=True) return loss_.item() def _preprocess_batch(self, batch: Batch): x, y, c, edge_index, edge_attr = ( batch.x, batch.y, batch.c, batch.edge_index, batch.edge_attr, ) edge_attr = 1 / edge_attr c_ij = self._compute_c_ij(c, edge_index) edge_attr = edge_attr * c_ij return x, y, edge_index, edge_attr def training_step(self, batch: Batch, batch_idx: int): optim = self.optimizers() optim.zero_grad() x, y, edge_index, edge_attr = self._preprocess_batch(batch) deg = self._compute_deg(edge_index, edge_attr, x.size(0)) losses = [] acc_loss, acc_it = 0, 0 max_acc_iters = ( self.current_iters // self.accumulation_iters + 1 if self.accumulation_iters is not None else 1 ) for i in range(self.current_iters): out = self._compute_model_steps( x, edge_index, edge_attr.unsqueeze(-1), deg, batch.boundary_mask, batch.boundary_values, ) losses.append(self.loss(out, y)) # Accumulate gradients if reached accumulation iters if ( self.accumulation_iters is not None and (i + 1) % self.accumulation_iters == 0 ): loss = self.accumulate_gradients(losses, max_acc_iters) losses = [] acc_it += 1 out = out.detach() acc_loss = acc_loss + loss # Check for convergence and break if converged (with final accumulation) converged = self._check_convergence(out, x) if converged: if losses: loss = self.accumulate_gradients(losses, max_acc_iters) acc_it += 1 acc_loss = acc_loss + loss break # Final accumulation if we are at the last iteration if i == self.current_iters - 1: if losses: loss = self.accumulate_gradients(losses, max_acc_iters) acc_it += 1 acc_loss = acc_loss + loss x = out if i % self.accumulation_iters != 0: loss = self.loss(out, y) loss.backward() optim.step() optim.zero_grad() self._log_loss(acc_loss / acc_it, batch, "train") self.log( "train/iterations", i + 1, on_step=False, on_epoch=True, prog_bar=True, batch_size=int(batch.num_graphs), ) if hasattr(self.model, "p"): self.log( "train/p", self.model.p, on_step=False, on_epoch=True, prog_bar=True, batch_size=int(batch.num_graphs), ) def on_train_epoch_end(self): if self.curriculum_learning: if (self.current_iters < self.max_iters) and ( self.current_epoch % self.increase_every == 0 ): self.current_iters = min( int(self.current_iters * self.increase_rate), self.max_iters ) return super().on_train_epoch_end() def validation_step(self, batch: Batch, _): x, y, edge_index, edge_attr = self._preprocess_batch(batch) deg = self._compute_deg(edge_index, edge_attr, x.size(0)) for i in range(self.current_iters): out = self._compute_model_steps( x, edge_index, edge_attr.unsqueeze(-1), deg, batch.boundary_mask, batch.boundary_values, ) converged = self._check_convergence(out, x) if converged: break x = out loss = 0.5 * self.loss(out, y) self._log_loss(loss, batch, "val") self.log( "val/iterations", i + 1, on_step=False, on_epoch=True, prog_bar=True, batch_size=int(batch.num_graphs), ) def test_step(self, batch: Batch, _): x, y, edge_index, edge_attr = self._preprocess_batch(batch) deg = self._compute_deg(edge_index, edge_attr, x.size(0)) for i in range(self.max_iters): out = self._compute_model_steps( x, edge_index, edge_attr.unsqueeze(-1), deg, batch.boundary_mask, batch.boundary_values, ) converged = self._check_convergence(out, x) # _plot_mesh(batch.pos, y, out, batch.batch, i) if converged: break x = out loss = self.loss(out, y) self._log_loss(loss, batch, "test") self.log( "test/iterations", i + 1, on_step=False, on_epoch=True, prog_bar=True, batch_size=int(batch.num_graphs), ) def configure_optimizers(self): optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3) return optimizer def _impose_bc(self, x: torch.Tensor, data: Batch): x[data.boundary_mask] = data.boundary_values return x