import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch.nn.utils import spectral_norm
from matplotlib.tri import Triangulation
from matplotlib import pyplot as plt


def _plot_mesh(y_pred, batch, iteration=None):
    """Plot true temperature, predicted temperature, and absolute error
    for the first sample of the batch on its triangulated mesh."""
    idx = batch.batch == 0
    y = batch.y[idx].detach().cpu()
    y_pred = y_pred[idx].detach().cpu()
    pos = batch.pos[idx].detach().cpu()
    tria = Triangulation(pos[:, 0], pos[:, 1])

    plt.figure(figsize=(18, 5))

    plt.subplot(1, 3, 1)
    plt.tricontourf(tria, y.squeeze().numpy(), levels=14)
    plt.colorbar()
    plt.title("True temperature")

    plt.subplot(1, 3, 2)
    plt.tricontourf(tria, y_pred.squeeze().numpy(), levels=14)
    plt.colorbar()
    plt.title("Predicted temperature")

    plt.subplot(1, 3, 3)
    plt.tricontourf(tria, torch.abs(y_pred - y).squeeze().numpy(), levels=14)
    plt.colorbar()
    plt.title("Error")

    plt.suptitle("GNO", fontsize=16)
    name = (
        f"images/gno_iter_{iteration:04d}.png"
        if iteration is not None
        else "gno.png"
    )
    plt.savefig(name, dpi=72)
    plt.close()


class FiniteDifferenceStep(MessagePassing):
    """Message-passing layer that mimics a finite-difference stencil.

    Node values and edge attributes are embedded with spectrally
    normalized MLPs, each message is the element-wise product of the
    neighbor and edge embeddings, and the aggregated result is normalized
    by the node degree before the update and output MLPs map it back to
    a scalar field per node.
    """

    def __init__(self, edge_ch=5, hidden_dim=16, aggr: str = "add"):
        super().__init__(aggr=aggr)

        self.x_embedding = nn.Sequential(
            spectral_norm(nn.Linear(1, hidden_dim // 2)),
            nn.GELU(),
            spectral_norm(nn.Linear(hidden_dim // 2, hidden_dim)),
        )
        self.edge_embedding = nn.Sequential(
            spectral_norm(nn.Linear(edge_ch, hidden_dim // 2)),
            nn.GELU(),
            spectral_norm(nn.Linear(hidden_dim // 2, hidden_dim)),
        )
        self.update_net = nn.Sequential(
            spectral_norm(nn.Linear(2 * hidden_dim, hidden_dim)),
            nn.GELU(),
            spectral_norm(nn.Linear(hidden_dim, hidden_dim)),
            nn.GELU(),
        )
        # Alternative message network, kept for reference (see message()):
        # self.message_net = nn.Sequential(
        #     spectral_norm(nn.Linear(2 * hidden_dim, hidden_dim)),
        #     nn.GELU(),
        #     spectral_norm(nn.Linear(hidden_dim, hidden_dim // 2)),
        #     nn.GELU(),
        #     spectral_norm(nn.Linear(hidden_dim // 2, hidden_dim)),
        # )
        self.out_net = nn.Sequential(
            spectral_norm(nn.Linear(hidden_dim, hidden_dim // 2)),
            nn.GELU(),
            spectral_norm(nn.Linear(hidden_dim // 2, 1)),
        )

    def forward(self, x, edge_index, edge_attr, deg):
        """Embed node and edge features, propagate messages, and map the
        residual-connected result back to one output channel per node."""
        x_ = self.x_embedding(x)
        edge_attr_ = self.edge_embedding(edge_attr)
        out = self.propagate(edge_index, x=x_, edge_attr=edge_attr_, deg=deg)
        return self.out_net(x_ + out)

    def message(self, x_j, edge_attr):
        """Weight the neighbor embedding by the edge embedding."""
        # Alternative, using the commented-out message_net:
        # msg_input = torch.cat([x_j, edge_attr], dim=-1)
        # return self.message_net(msg_input) * edge_attr[:, 3].view(-1, 1)
        return x_j * edge_attr

    def update(self, aggr_out, x):
        """Combine the aggregated messages with the node embedding."""
        update_input = torch.cat([x, aggr_out], dim=-1)
        return self.update_net(update_input)

    def aggregate(self, inputs, index, deg):
        """Sum incoming messages and normalize by the node degree."""
        out = super().aggregate(inputs, index)
        deg = deg + 1e-7  # avoid division by zero for isolated nodes
        return out / deg.view(-1, 1)


# TODO:
# - Finish computing the loss at every step, then average
# - Test with different models
# - If that does not work, try adjusting the exit criterion
# PINN batching:
# - Try a single condition
# - Second-order optimizer (LBFGS)
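
# A minimal smoke-test sketch for FiniteDifferenceStep, assuming one scalar
# field per node, edge_ch = 5 edge features, and a per-node degree vector as
# consumed by aggregate() above. The random graph, tensor shapes, and sizes
# are assumptions for illustration only, not part of the training pipeline.
if __name__ == "__main__":
    from torch_geometric.utils import degree

    num_nodes, num_edges = 10, 30
    x = torch.rand(num_nodes, 1)                      # scalar field per node
    edge_index = torch.randint(0, num_nodes, (2, num_edges))
    edge_attr = torch.rand(num_edges, 5)              # matches edge_ch = 5
    deg = degree(edge_index[1], num_nodes=num_nodes)  # in-degree per node

    step = FiniteDifferenceStep(edge_ch=5, hidden_dim=16)
    out = step(x, edge_index, edge_attr, deg)
    print(out.shape)  # expected: torch.Size([10, 1])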