diff --git a/ThermalSolver/graph_module.py b/ThermalSolver/graph_module.py
index 90cc8b0..bf80043 100644
--- a/ThermalSolver/graph_module.py
+++ b/ThermalSolver/graph_module.py
@@ -13,7 +13,7 @@ def import_class(class_path: str):
     return cls
 
 
-def _plot_mesh(pos, y, y_pred, batch):
+def _plot_mesh(pos, y, y_pred, batch, i):
     idx = batch == 0
 
     y = y[idx].detach().cpu()
@@ -36,41 +36,41 @@ def _plot_mesh(pos, y, y_pred, batch):
     plt.colorbar()
     plt.title("Error")
     plt.suptitle("GNO", fontsize=16)
-    plt.savefig("gno.png", dpi=300)
+    name = f"images/graph_iter_{i:04d}.png"
+    plt.savefig(name, dpi=72)
+    plt.close()
 
 
 class GraphSolver(LightningModule):
     def __init__(
         self,
         model_class_path: str,
-        model_init_args: dict,
+        model_init_args: dict = None,
         loss: torch.nn.Module = None,
-        unrolling_steps: int = 48,
+        curriculum_learning: bool = False,
+        start_iters: int = 10,
+        increase_every: int = 100,
+        increase_rate: float = 1.1,
+        max_iters: int = 1000,
+        accumulation_iters: int = None,
     ):
         super().__init__()
-        self.model = import_class(model_class_path)(**model_init_args)
+        # None instead of a mutable {} default; treat it as "no init args".
+        self.model = import_class(model_class_path)(**(model_init_args or {}))
         self.loss = loss if loss is not None else torch.nn.MSELoss()
-        self.unrolling_steps = unrolling_steps
+        self.curriculum_learning = curriculum_learning
+        self.start_iters = start_iters
+        self.increase_every = increase_every
+        self.increase_rate = increase_rate
+        self.max_iters = max_iters
+        self.current_iters = start_iters
+        self.accumulation_iters = accumulation_iters
+        # Manual optimization: training_step drives backward/step itself.
+        self.automatic_optimization = False
+        self.threshold = 1e-2
 
-    def forward(
-        self,
-        x: torch.Tensor,
-        c: torch.Tensor,
-        edge_index: torch.Tensor,
-        edge_attr: torch.Tensor,
-        unrolling_steps: int = None,
-        boundary_mask: torch.Tensor = None,
-        boundary_values: torch.Tensor = None,
-    ):
-        return self.model(
-            x=x,
-            c=c,
-            edge_index=edge_index,
-            edge_attr=edge_attr,
-            unrolling_steps=unrolling_steps,
-            boundary_mask=boundary_mask,
-            boundary_values=boundary_values,
-        )
+    def _compute_deg(self, edge_index, edge_attr, num_nodes):
+        # Weighted in-degree: sum of incoming edge weights per node, with a
+        # small epsilon so the later division is safe.
+        deg = torch.zeros(num_nodes, device=edge_index.device)
+        deg = deg.scatter_add(0, edge_index[1], edge_attr)
+        return deg + 1e-7
 
     def _compute_loss(self, x, y):
         return self.loss(x, y)
@@ -89,89 +89,207 @@ class GraphSolver(LightningModule):
         )
         return loss
 
-    def training_step(self, batch: Batch, _):
+    @staticmethod
+    def _compute_c_ij(c, edge_index):
+        """
+        Average the node-wise coefficient c onto each edge:
+        c_ij = (c_i + c_j) / 2.
+ """ + return (0.5 * (c[edge_index[0]] + c[edge_index[1]])).squeeze() + + def _compute_model_steps( + self, x, edge_index, edge_attr, deg, boundary_mask, boundary_values + ): + out = self.model(x, edge_index, edge_attr, deg) + out[boundary_mask] = boundary_values.unsqueeze(-1) + return out + + def _check_convergence(self, out, x): + residual_norm = torch.norm(out - x) + if residual_norm < self.threshold: + return True + return False + + def accumulate_gradients(self, losses, max_acc_iters): + loss_ = torch.stack(losses, dim=0).mean() + loss = loss_ / self.accumulation_iters + self.manual_backward(loss / max_acc_iters) + return loss_.item() + + def training_step(self, batch: Batch, batch_idx: int): + optim = self.optimizers() + optim.zero_grad() x, y, c, edge_index, edge_attr = self._preprocess_batch(batch) - y_pred, it = self( - x, - c, - edge_index=edge_index, - edge_attr=edge_attr, - unrolling_steps=self.unrolling_steps, - boundary_mask=batch.boundary_mask, - boundary_values=batch.boundary_values, + + edge_w = 1 / edge_attr[:, -1] + c_ij = self._compute_c_ij(c, edge_index) + edge_w = edge_w * c_ij + deg = self._compute_deg(edge_index, edge_w, x.size(0)) + + edge_attr = torch.cat( + [edge_attr, edge_w.unsqueeze(-1), c_ij.unsqueeze(-1)], dim=1 ) - loss = self.loss(y_pred, y) - boundary_loss = self.loss( - y_pred[batch.boundary_mask], y[batch.boundary_mask] + losses = [] + acc_loss, acc_it = 0, 0 + max_acc_iters = ( + self.current_iters // self.accumulation_iters + 1 + if self.accumulation_iters is not None + else 1 ) - self._log_loss(loss, batch, "train") - # self._log_loss(boundary_loss, batch, "train_boundary") + for i in range(self.current_iters): + out = self._compute_model_steps( + x, + edge_index, + edge_attr, + deg, + batch.boundary_mask, + batch.boundary_values, + ) + losses.append(self.loss(out, y)) + + # Accumulate gradients if reached accumulation iters + if ( + self.accumulation_iters is not None + and (i + 1) % self.accumulation_iters == 0 + ): + loss = self.accumulate_gradients(losses, max_acc_iters) + losses = [] + acc_it += 1 + out = out.detach() + acc_loss = acc_loss + loss + + # Check for convergence and break if converged (with final accumulation) + converged = self._check_convergence(out, x) + if converged: + if losses: + loss = self.accumulate_gradients(losses, max_acc_iters) + acc_it += 1 + acc_loss = acc_loss + loss + break + + # Final accumulation if we are at the last iteration + if i == self.current_iters - 1: + if losses: + loss = self.accumulate_gradients(losses, max_acc_iters) + acc_it += 1 + acc_loss = acc_loss + loss + + x = out + + optim.step() + optim.zero_grad() + + self._log_loss(acc_loss / acc_it, batch, "train") self.log( "train/iterations", - it, + i + 1, on_step=False, on_epoch=True, prog_bar=True, batch_size=int(batch.num_graphs), ) - self.log( - "train/param_p", - self.model.fd_step.p, - on_step=False, - on_epoch=True, - prog_bar=True, - batch_size=int(batch.num_graphs), - ) - # self.log("train/param_a", self.model.fd_step.a, on_step=False, on_epoch=True, prog_bar=True, batch_size=int(batch.num_graphs)) - return loss + + def on_train_epoch_end(self): + if self.curriculum_learning: + if (self.current_iters < self.max_iters) and ( + self.current_epoch % self.increase_every == 0 + ): + self.current_iters = min( + int(self.current_iters * self.increase_rate), self.max_iters + ) + return super().on_train_epoch_end() def validation_step(self, batch: Batch, _): x, y, c, edge_index, edge_attr = self._preprocess_batch(batch) - y_pred, it = self( - x, - c, - 
-            edge_index=edge_index,
-            edge_attr=edge_attr,
-            unrolling_steps=self.unrolling_steps,
-            boundary_mask=batch.boundary_mask,
-            boundary_values=batch.boundary_values,
-        )
-        loss = self.loss(y_pred, y)
-        boundary_loss = self.loss(
-            y_pred[batch.boundary_mask], y[batch.boundary_mask]
+
+        edge_w = 1 / edge_attr[:, -1]
+        c_ij = self._compute_c_ij(c, edge_index)
+        edge_w = edge_w * c_ij
+        deg = self._compute_deg(edge_index, edge_w, x.size(0))
+
+        edge_attr = torch.cat(
+            [edge_attr, edge_w.unsqueeze(-1), c_ij.unsqueeze(-1)], dim=1
         )
+        for i in range(self.current_iters):
+            out = self._compute_model_steps(
+                x,
+                edge_index,
+                edge_attr,
+                deg,
+                batch.boundary_mask,
+                batch.boundary_values,
+            )
+            converged = self._check_convergence(out, x)
+
+            if converged:
+                break
+            x = out
+        loss = self.loss(out, y)
         self._log_loss(loss, batch, "val")
         self.log(
             "val/iterations",
-            it,
+            i + 1,
             on_step=False,
             on_epoch=True,
             prog_bar=True,
             batch_size=int(batch.num_graphs),
         )
-        return loss
 
     def test_step(self, batch: Batch, _):
         x, y, c, edge_index, edge_attr = self._preprocess_batch(batch)
-        y_pred, _ = self.model(
-            x=x,
-            c=c,
-            edge_index=edge_index,
-            edge_attr=edge_attr,
-            unrolling_steps=self.unrolling_steps,
-            batch=batch.batch,
-            pos=batch.pos,
-            boundary_mask=batch.boundary_mask,
-            boundary_values=batch.boundary_values,
-            plot_results=False,
+        edge_w = 1 / edge_attr[:, -1]
+        c_ij = self._compute_c_ij(c, edge_index)
+        edge_w = edge_w * c_ij
+        deg = self._compute_deg(edge_index, edge_w, x.size(0))
+
+        edge_attr = torch.cat(
+            [edge_attr, edge_w.unsqueeze(-1), c_ij.unsqueeze(-1)], dim=1
         )
-        loss = self._compute_loss(y_pred, y)
-        _plot_mesh(batch.pos, y, y_pred, batch.batch)
+        # At test time, iterate up to max_iters and plot every iteration.
+        for i in range(self.max_iters):
+            out = self._compute_model_steps(
+                x,
+                edge_index,
+                edge_attr,
+                deg,
+                batch.boundary_mask,
+                batch.boundary_values,
+            )
+            converged = self._check_convergence(out, x)
+            _plot_mesh(batch.pos, y, out, batch.batch, i)
+            if converged:
+                break
+            x = out
+        loss = self.loss(out, y)
         self._log_loss(loss, batch, "test")
-        return loss
+        self.log(
+            "test/iterations",
+            i + 1,
+            on_step=False,
+            on_epoch=True,
+            prog_bar=True,
+            batch_size=int(batch.num_graphs),
+        )
 
     def configure_optimizers(self):
-        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
+        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
         return optimizer
 
     def _impose_bc(self, x: torch.Tensor, data: Batch):
diff --git a/ThermalSolver/model/__init__.py b/ThermalSolver/model/__init__.py
index 37adaca..1b87f07 100644
--- a/ThermalSolver/model/__init__.py
+++ b/ThermalSolver/model/__init__.py
@@ -1,13 +1,13 @@
 __all__ = [
-    "GraphFiniteDifference",
+    # "GraphFiniteDifference",
     "GatingGNO",
-    "LearnableGraphFiniteDifference",
+    # "LearnableGraphFiniteDifference",
     "PointNet",
 ]
 
-from .learnable_finite_difference import (
-    GraphFiniteDifference as LearnableGraphFiniteDifference,
-)
-from .finite_difference import GraphFiniteDifference as GraphFiniteDifference
+# from .learnable_finite_difference import (
+#     GraphFiniteDifference as LearnableGraphFiniteDifference,
+# )
+# from .finite_difference import GraphFiniteDifference as GraphFiniteDifference
 
 from .local_gno import GatingGNO
 from .point_net import PointNet
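The chunked backward pass in `GraphSolver.training_step` above is truncated backpropagation through the unrolled solver: every `accumulation_iters` steps the mean loss of the chunk is backpropagated and the state is detached. A minimal self-contained sketch of the scheme, with a hypothetical one-parameter model standing in for the real GNN (all names and numbers here are illustrative, not project code):

```python
import torch

# Hypothetical stand-in for the unrolled GNN solver: one parameter,
# applied repeatedly to the state.
w = torch.nn.Parameter(torch.tensor(0.9))
opt = torch.optim.AdamW([w], lr=1e-3)
x, y = torch.tensor([2.0]), torch.tensor([0.5])

steps, acc = 10, 4                 # current_iters, accumulation_iters
max_acc_iters = steps // acc + 1   # upper bound on the number of chunks

opt.zero_grad()
losses, acc_loss, acc_it = [], 0.0, 0
for i in range(steps):
    out = w * x                            # one relaxation step
    losses.append((out - y).pow(2).mean())
    if (i + 1) % acc == 0 or i == steps - 1:
        # Flush the chunk: backpropagate its mean loss, scaled by the
        # number of chunks (manual_backward in Lightning).
        chunk = torch.stack(losses).mean()
        (chunk / max_acc_iters).backward()
        losses = []
        acc_loss, acc_it = acc_loss + chunk.item(), acc_it + 1
        out = out.detach()                 # truncate the graph between chunks
    x = out
opt.step()
print(f"mean chunk loss {acc_loss / acc_it:.4f}, grad {w.grad.item():.4f}")
```

Note that `steps // acc + 1` over-counts the chunks when `steps` is a multiple of `acc`, so the accumulated gradient is then a slightly down-scaled mean of the chunk losses; this only rescales the effective learning rate.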
diff --git a/ThermalSolver/model/finite_difference.py b/ThermalSolver/model/finite_difference.py
index fe50eb3..f183f52 100644
--- a/ThermalSolver/model/finite_difference.py
+++ b/ThermalSolver/model/finite_difference.py
@@ -14,7 +14,7 @@ class FiniteDifferenceStep(MessagePassing):
             aggr == "add"
         ), "Per somme pesate, l'aggregazione deve essere 'add'."
         # self.root_weight = float(root_weight)
-        self.p = torch.nn.Parameter(torch.tensor(0.8))
+        self.p = torch.nn.Parameter(torch.tensor(1.0))
         self.a = root_weight
 
     def forward(self, x, edge_index, edge_attr, deg):
@@ -43,9 +43,7 @@ class FiniteDifferenceStep(MessagePassing):
         """
        TODO: add docstring.
         """
-        a = torch.clamp(self.a, 0.0, 1.0)
-        return a * aggr_out + (1 - a) * x
-        # return self.a * aggr_out + (1 - self.a) * x
+        return aggr_out
 
 
 class GraphFiniteDifference(nn.Module):
diff --git a/ThermalSolver/model/learnable_finite_difference.py b/ThermalSolver/model/learnable_finite_difference.py
index 8dc6b54..28201bb 100644
--- a/ThermalSolver/model/learnable_finite_difference.py
+++ b/ThermalSolver/model/learnable_finite_difference.py
@@ -2,6 +2,40 @@ import torch
 import torch.nn as nn
 from torch_geometric.nn import MessagePassing
 from torch.nn.utils import spectral_norm
+from matplotlib.tri import Triangulation
+from matplotlib import pyplot as plt
+
+
+def _plot_mesh(y_pred, batch, iteration=None):
+    # Plot the first graph of the batch: true field, prediction, and error.
+    idx = batch.batch == 0
+    y = batch.y[idx].detach().cpu()
+    y_pred = y_pred[idx].detach().cpu()
+    pos = batch.pos[idx].detach().cpu()
+
+    tria = Triangulation(pos[:, 0], pos[:, 1])
+    plt.figure(figsize=(18, 5))
+    plt.subplot(1, 3, 1)
+    plt.tricontourf(tria, y.squeeze().numpy(), levels=14)
+    plt.colorbar()
+    plt.title("True temperature")
+    plt.subplot(1, 3, 2)
+    plt.tricontourf(tria, y_pred.squeeze().numpy(), levels=14)
+    plt.colorbar()
+    plt.title("Predicted temperature")
+    plt.subplot(1, 3, 3)
+    plt.tricontourf(tria, torch.abs(y_pred - y).squeeze().numpy(), levels=14)
+    plt.colorbar()
+    plt.title("Error")
+    plt.suptitle("GNO", fontsize=16)
+    name = (
+        f"images/gno_iter_{iteration:04d}.png"
+        if iteration is not None
+        else "gno.png"
+    )
+    plt.savefig(name, dpi=72)
+    plt.close()
 
 
 class FiniteDifferenceStep(MessagePassing):
@@ -9,50 +43,69 @@ class FiniteDifferenceStep(MessagePassing):
     TODO: add docstring.
     """
 
-    def __init__(self, aggr: str = "add", root_weight: float = 1.0):
+    def __init__(self, edge_ch=5, hidden_dim=16, aggr: str = "add"):
         super().__init__(aggr=aggr)
-        assert (
-            aggr == "add"
-        ), "Per somme pesate, l'aggregazione deve essere 'add'."
-
-        self.correction_net = nn.Sequential(
-            nn.Linear(2, 6),
-            nn.Tanh(),
-            nn.Linear(6, 1),
-            nn.Tanh(),
+        self.x_embedding = nn.Sequential(
+            spectral_norm(nn.Linear(1, hidden_dim // 2)),
+            nn.GELU(),
+            spectral_norm(nn.Linear(hidden_dim // 2, hidden_dim)),
         )
+
+        self.edge_embedding = nn.Sequential(
+            spectral_norm(nn.Linear(edge_ch, hidden_dim // 2)),
+            nn.GELU(),
+            spectral_norm(nn.Linear(hidden_dim // 2, hidden_dim)),
+        )
+
         self.update_net = nn.Sequential(
-            spectral_norm(nn.Linear(1, 6)),
-            nn.Softplus(),
-            spectral_norm(nn.Linear(6, 1)),
-            nn.Softplus(),
+            spectral_norm(nn.Linear(2 * hidden_dim, hidden_dim)),
+            nn.GELU(),
+            spectral_norm(nn.Linear(hidden_dim, hidden_dim)),
+            nn.GELU(),
         )
-        self.message_net = nn.Sequential(
-            spectral_norm(nn.Linear(1, 6)),
-            nn.Softplus(),
-            spectral_norm(nn.Linear(6, 1)),
-            nn.Softplus(),
+
+        self.out_net = nn.Sequential(
+            spectral_norm(nn.Linear(hidden_dim, hidden_dim // 2)),
+            nn.GELU(),
+            spectral_norm(nn.Linear(hidden_dim // 2, 1)),
         )
-        self.p = torch.nn.Parameter(torch.tensor(0.5))
-        # self.a = torch.nn.Parameter(torch.tensor(root_weight))
 
     def forward(self, x, edge_index, edge_attr, deg):
         """
         TODO: add docstring.
         """
-        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, deg=deg)
-        return out
+        # Embed node state and edge features, propagate, then decode with a
+        # residual connection on the node embedding.
+        x_ = self.x_embedding(x)
+        edge_attr_ = self.edge_embedding(edge_attr)
+        out = self.propagate(edge_index, x=x_, edge_attr=edge_attr_, deg=deg)
+        return self.out_net(x_ + out)
 
     def message(self, x_j, edge_attr):
         """
         TODO: add docstring.
         """
-        # x_in = torch.cat([x_j, edge_attr.view(-1, 1)], dim=-1)
-        # correction = self.correction_net(x_in)
-        # p = torch.sigmoid(self.p)
-        # return (p * edge_attr.view(-1, 1) + (1 - p) * correction) * x_j
-        return edge_attr.view(-1, 1) * x_j
+        # Gate the neighbour embedding with the embedded edge features.
+        return x_j * edge_attr
+
+    def update(self, aggr_out, x):
+        """
+        TODO: add docstring.
+        """
+        update_input = torch.cat([x, aggr_out], dim=-1)
+        return self.update_net(update_input)
 
     def aggregate(self, inputs, index, deg):
         """
@@ -62,68 +115,12 @@
         deg = deg + 1e-7
         return out / deg.view(-1, 1)
 
-    def update(self, aggr_out, x):
-        """
-        TODO: add docstring.
-        """
-        return self.update_net(aggr_out)
 
+# TODO:
+# - Finish computing the loss at every step, then averaging
+# - Test with various models
+# - If that does not work, try adjusting the exit criterion
 
-class GraphFiniteDifference(nn.Module):
-    """
-    TODO: add docstring.
-    """
-
-    def __init__(self, max_iters: int = 5000, threshold: float = 1e-4):
-        """
-        TODO: add docstring.
-        """
-        super().__init__()
-        self.max_iters = max_iters
-        self.threshold = threshold
-        self.fd_step = FiniteDifferenceStep(aggr="add", root_weight=1.0)
-
-    @staticmethod
-    def _compute_deg(edge_index, edge_attr, num_nodes):
-        """
-        TODO: add docstring.
-        """
-        deg = torch.zeros(num_nodes, device=edge_index.device)
-        deg = deg.scatter_add(0, edge_index[1], edge_attr)
-        return deg + 1e-7
-
-    @staticmethod
-    def _compute_c_ij(c, edge_index):
-        """
-        TODO: add docstring.
- """ - return (0.5 * (c[edge_index[0]] + c[edge_index[1]])).squeeze() - - def forward( - self, - x, - edge_index, - edge_attr, - c, - boundary_mask, - boundary_values, - **kwargs, - ): - """ - TODO: add docstring. - """ - edge_attr = 1 / edge_attr[:, -1] - c_ij = self._compute_c_ij(c, edge_index) - edge_attr = edge_attr * c_ij - deg = self._compute_deg(edge_index, edge_attr, x.size(0)) - conv_thres = self.threshold * torch.norm(x.detach()) - - for _i in range(self.max_iters): - out = self.fd_step(x, edge_index, edge_attr, deg) - out[boundary_mask] = boundary_values.unsqueeze(-1) - with torch.no_grad(): - residual_norm = torch.norm(out - x) - if residual_norm < conv_thres: - break - x = out.detach() - return out, _i + 1 +# # PINN batching: +# # - Provare singola condizione +# # - Ottimizzatore del secondo ordine (LBFGS)