import torch from lightning import LightningModule from torch_geometric.data import Batch import importlib from matplotlib import pyplot as plt from matplotlib.tri import Triangulation from .model.finite_difference import FiniteDifferenceStep import os def import_class(class_path: str): module_path, class_name = class_path.rsplit(".", 1) # split last dot module = importlib.import_module(module_path) # import the module cls = getattr(module, class_name) # get the class return cls def _plot_mesh(pos, y, y_pred, batch, i, batch_idx): idx = batch == 0 y = y[idx].detach().cpu() y_pred = y_pred[idx].detach().cpu() pos = pos[idx].detach().cpu() folder = f"{batch_idx:02d}_images" if os.path.exists(folder) is False: os.makedirs(folder) pos = pos.detach().cpu() tria = Triangulation(pos[:, 0], pos[:, 1]) plt.figure(figsize=(18, 5)) plt.subplot(1, 3, 1) plt.tricontourf(tria, y.squeeze().numpy(), levels=14) plt.colorbar() plt.title("True temperature") plt.subplot(1, 3, 2) plt.tricontourf(tria, y_pred.squeeze().numpy(), levels=14) plt.colorbar() plt.title("Predicted temperature") plt.subplot(1, 3, 3) plt.tricontourf(tria, torch.abs(y_pred - y).squeeze().numpy(), levels=14) plt.colorbar() plt.title("Error") plt.suptitle("GNO", fontsize=16) name = f"{folder}/graph_iter_{i:04d}.png" plt.savefig(name, dpi=72) plt.close() def _plot_losses(losses, batch_idx): folder = f"{batch_idx:02d}_images" plt.figure() plt.plot(losses) plt.yscale("log") plt.xlabel("Iteration") plt.ylabel("Loss") plt.title("Test Loss over Iterations") plt.grid(True) file_name = f"{folder}/test_loss.png" plt.savefig(file_name, dpi=300) plt.close() class GraphSolver(LightningModule): def __init__( self, model_class_path: str, model_init_args: dict = {}, loss: torch.nn.Module = None, curriculum_learning: bool = False, start_iters: int = 10, increase_every: int = 100, increase_rate: float = 1.1, max_iters: int = 1000, accumulation_iters: int = None, ): super().__init__() self.model = import_class(model_class_path)(**model_init_args) self.fd_net = FiniteDifferenceStep() self.loss = loss if loss is not None else torch.nn.MSELoss() self.curriculum_learning = curriculum_learning self.start_iters = start_iters self.increase_every = increase_every self.increase_rate = increase_rate self.max_iters = max_iters self.current_iters = start_iters self.accumulation_iters = accumulation_iters self.automatic_optimization = False self.threshold = 1e-5 self.alpha = torch.nn.Parameter(torch.tensor(0.1)) def _compute_deg(self, edge_index, edge_attr, num_nodes): deg = torch.zeros(num_nodes, device=edge_index.device) deg = deg.scatter_add(0, edge_index[1], edge_attr) return deg + 1e-7 def _compute_loss(self, x, y): return self.loss(x, y) def _log_loss(self, loss, batch, stage: str): self.log( f"{stage}/loss", loss, on_step=False, on_epoch=True, prog_bar=True, batch_size=int(batch.num_graphs), ) return loss @staticmethod def _compute_c_ij(c, edge_index): """ TODO: add docstring. """ return (0.5 * (c[edge_index[0]] + c[edge_index[1]])).squeeze() def _compute_model_steps( self, x, edge_index, edge_attr, deg, boundary_mask, boundary_values ): # with torch.no_grad(): # out = self.fd_net(x, edge_index, edge_attr, deg) # out[boundary_mask] = boundary_values.unsqueeze(-1) # diff = out - x # out = self.model(out, edge_index, edge_attr, deg) # out = out + self.alpha * correction # out[boundary_mask] = boundary_values.unsqueeze(-1) out = self.model(x, edge_index, edge_attr, deg) out[boundary_mask] = boundary_values.unsqueeze(-1) return out def _check_convergence(self, out, x): residual_norm = torch.norm(out - x) if residual_norm < self.threshold * torch.norm(x): return True return False def accumulate_gradients(self, losses): loss_ = torch.stack(losses, dim=0).mean() self.manual_backward(loss_, retain_graph=True) return loss_.item() def _preprocess_batch(self, batch: Batch): x, y, c, edge_index, edge_attr = ( batch.x, batch.y, batch.c, batch.edge_index, batch.edge_attr, ) edge_attr = 1 / edge_attr c_ij = self._compute_c_ij(c, edge_index) edge_attr = edge_attr * c_ij return x, y, edge_index, edge_attr def training_step(self, batch: Batch, _): optim = self.optimizers() optim.zero_grad() x, y, edge_index, edge_attr = self._preprocess_batch(batch) deg = self._compute_deg(edge_index, edge_attr, x.size(0)) losses = [] acc_loss, acc_it = 0, 0 for i in range(self.current_iters): out = self._compute_model_steps( x, edge_index, edge_attr.unsqueeze(-1), deg, batch.boundary_mask, batch.boundary_values, ) losses.append(self.loss(out, y)) # Accumulate gradients if reached accumulation iters if ( self.accumulation_iters is not None and (i + 1) % self.accumulation_iters == 0 ): loss = self.accumulate_gradients(losses) losses = [] acc_it += 1 out = out.detach() acc_loss = acc_loss + loss # Check for convergence and break if converged (with final accumulation) converged = self._check_convergence(out, x) if converged: if losses: loss = self.accumulate_gradients(losses) acc_it += 1 acc_loss = acc_loss + loss break # Final accumulation if we are at the last iteration if i == self.current_iters - 1: if losses: loss = self.accumulate_gradients(losses) acc_it += 1 acc_loss = acc_loss + loss x = out loss = self.loss(out, y) for param in self.model.parameters(): if param.grad is not None: param.grad /= acc_it optim.step() optim.zero_grad() self.log( "train/accumulated_loss", (acc_loss / acc_it if acc_it > 0 else acc_loss), on_step=False, on_epoch=True, prog_bar=True, batch_size=int(batch.num_graphs), ) self.log( "train/iterations", i + 1, on_step=False, on_epoch=True, prog_bar=True, batch_size=int(batch.num_graphs), ) if hasattr(self.model, "p"): self.log( "train/p", self.model.p, on_step=False, on_epoch=True, prog_bar=True, batch_size=int(batch.num_graphs), ) def on_train_epoch_end(self): if self.curriculum_learning: if (self.current_iters < self.max_iters) and ( self.current_epoch % self.increase_every == 0 ): self.current_iters = min( int(self.current_iters * self.increase_rate), self.max_iters ) return super().on_train_epoch_end() def validation_step(self, batch: Batch, _): x, y, edge_index, edge_attr = self._preprocess_batch(batch) deg = self._compute_deg(edge_index, edge_attr, x.size(0)) for i in range(self.current_iters): out = self._compute_model_steps( x, edge_index, edge_attr.unsqueeze(-1), deg, batch.boundary_mask, batch.boundary_values, ) converged = self._check_convergence(out, x) if converged: break x = out loss = self.loss(out, y) self._log_loss(loss, batch, "val") self.log( "val/iterations", i + 1, on_step=False, on_epoch=True, prog_bar=True, batch_size=int(batch.num_graphs), ) def test_step(self, batch: Batch, batch_idx): pass def configure_optimizers(self): optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3) return optimizer def _impose_bc(self, x: torch.Tensor, data: Batch): x[data.boundary_mask] = data.boundary_values return x