From 6e90ef5393d7b8396bada132ce076867f86a65c5 Mon Sep 17 00:00:00 2001 From: Filippo Olivo Date: Mon, 27 Oct 2025 10:23:13 +0100 Subject: [PATCH] random changes --- ThermalSolver/graph_datamodule.py | 9 +- ThermalSolver/graph_module.py | 86 +++++++-- ThermalSolver/model/__init__.py | 2 +- ThermalSolver/model/basic_gno.py | 25 --- ThermalSolver/model/finite_difference.py | 56 +++--- ThermalSolver/model/point_net.py | 216 ++++++++++++++++++++++- ThermalSolver/point_module.py | 11 +- 7 files changed, 325 insertions(+), 80 deletions(-) delete mode 100644 ThermalSolver/model/basic_gno.py diff --git a/ThermalSolver/graph_datamodule.py b/ThermalSolver/graph_datamodule.py index c5d4ce5..a3c1145 100644 --- a/ThermalSolver/graph_datamodule.py +++ b/ThermalSolver/graph_datamodule.py @@ -6,7 +6,6 @@ from torch_geometric.data import Data from torch_geometric.loader import DataLoader from torch_geometric.utils import to_undirected from .mesh_data import MeshData -import os class GraphDataModule(LightningDataModule): @@ -18,7 +17,7 @@ class GraphDataModule(LightningDataModule): val_size: float = 0.1, test_size: float = 0.1, batch_size: int = 32, - remove_boundary_edges: bool = True, + remove_boundary_edges: bool = False, ): super().__init__() self.hf_repo = hf_repo @@ -82,6 +81,7 @@ class GraphDataModule(LightningDataModule): temperature = torch.tensor(snapshot["temperature"], dtype=torch.float32) edge_index = torch.tensor(geometry["edge_index"], dtype=torch.int64).T + pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2] bottom_ids = torch.tensor( geometry["bottom_boundary_ids"], dtype=torch.long @@ -97,7 +97,6 @@ class GraphDataModule(LightningDataModule): boundary_mask, boundary_values = self._compute_boundary_mask( bottom_ids, right_ids, top_ids, left_ids, temperature ) - if self.remove_boundary_edges: boundary_idx = torch.unique(boundary_mask) edge_index_mask = ~torch.isin(edge_index[1], boundary_idx) @@ -119,7 +118,7 @@ class GraphDataModule(LightningDataModule): edge_attr=edge_attr, y=temperature.unsqueeze(-1), boundary_mask=boundary_mask, - boundary_values=torch.tensor(0), # Fake value (to fix) + boundary_values=boundary_values, ) return MeshData( @@ -129,7 +128,7 @@ class GraphDataModule(LightningDataModule): pos=pos, edge_attr=edge_attr, boundary_mask=boundary_mask, - boundary_values=boundary_values.unsqueeze(-1), + boundary_values=boundary_values, y=temperature.unsqueeze(-1), ) diff --git a/ThermalSolver/graph_module.py b/ThermalSolver/graph_module.py index b530310..90cc8b0 100644 --- a/ThermalSolver/graph_module.py +++ b/ThermalSolver/graph_module.py @@ -2,6 +2,8 @@ import torch from lightning import LightningModule from torch_geometric.data import Batch import importlib +from matplotlib import pyplot as plt +from matplotlib.tri import Triangulation def import_class(class_path: str): @@ -11,6 +13,32 @@ def import_class(class_path: str): return cls +def _plot_mesh(pos, y, y_pred, batch): + + idx = batch == 0 + y = y[idx].detach().cpu() + y_pred = y_pred[idx].detach().cpu() + pos = pos[idx].detach().cpu() + + pos = pos.detach().cpu() + tria = Triangulation(pos[:, 0], pos[:, 1]) + plt.figure(figsize=(18, 5)) + plt.subplot(1, 3, 1) + plt.tricontourf(tria, y.squeeze().numpy(), levels=14) + plt.colorbar() + plt.title("True temperature") + plt.subplot(1, 3, 2) + plt.tricontourf(tria, y_pred.squeeze().numpy(), levels=14) + plt.colorbar() + plt.title("Predicted temperature") + plt.subplot(1, 3, 3) + plt.tricontourf(tria, torch.abs(y_pred - y).squeeze().numpy(), levels=14) + plt.colorbar() + plt.title("Error") + plt.suptitle("GNO", fontsize=16) + plt.savefig("gno.png", dpi=300) + + class GraphSolver(LightningModule): def __init__( self, @@ -32,14 +60,16 @@ class GraphSolver(LightningModule): edge_attr: torch.Tensor, unrolling_steps: int = None, boundary_mask: torch.Tensor = None, + boundary_values: torch.Tensor = None, ): return self.model( - x, - c, - edge_index, - edge_attr, - unrolling_steps, + x=x, + c=c, + edge_index=edge_index, + edge_attr=edge_attr, + unrolling_steps=unrolling_steps, boundary_mask=boundary_mask, + boundary_values=boundary_values, ) def _compute_loss(self, x, y): @@ -61,52 +91,82 @@ class GraphSolver(LightningModule): def training_step(self, batch: Batch, _): x, y, c, edge_index, edge_attr = self._preprocess_batch(batch) - y_pred = self( + y_pred, it = self( x, c, edge_index=edge_index, edge_attr=edge_attr, unrolling_steps=self.unrolling_steps, boundary_mask=batch.boundary_mask, + boundary_values=batch.boundary_values, ) loss = self.loss(y_pred, y) boundary_loss = self.loss( y_pred[batch.boundary_mask], y[batch.boundary_mask] ) self._log_loss(loss, batch, "train") - self._log_loss(boundary_loss, batch, "train_boundary") + # self._log_loss(boundary_loss, batch, "train_boundary") + self.log( + "train/iterations", + it, + on_step=False, + on_epoch=True, + prog_bar=True, + batch_size=int(batch.num_graphs), + ) + self.log( + "train/param_p", + self.model.fd_step.p, + on_step=False, + on_epoch=True, + prog_bar=True, + batch_size=int(batch.num_graphs), + ) + # self.log("train/param_a", self.model.fd_step.a, on_step=False, on_epoch=True, prog_bar=True, batch_size=int(batch.num_graphs)) return loss def validation_step(self, batch: Batch, _): x, y, c, edge_index, edge_attr = self._preprocess_batch(batch) - y_pred = self( + y_pred, it = self( x, c, edge_index=edge_index, edge_attr=edge_attr, unrolling_steps=self.unrolling_steps, + boundary_mask=batch.boundary_mask, + boundary_values=batch.boundary_values, ) loss = self.loss(y_pred, y) boundary_loss = self.loss( y_pred[batch.boundary_mask], y[batch.boundary_mask] ) self._log_loss(loss, batch, "val") - self._log_loss(boundary_loss, batch, "val_boundary") + self.log( + "val/iterations", + it, + on_step=False, + on_epoch=True, + prog_bar=True, + batch_size=int(batch.num_graphs), + ) return loss def test_step(self, batch: Batch, _): x, y, c, edge_index, edge_attr = self._preprocess_batch(batch) - y_pred = self.model( - x, - c, + y_pred, _ = self.model( + x=x, + c=c, edge_index=edge_index, edge_attr=edge_attr, unrolling_steps=self.unrolling_steps, batch=batch.batch, pos=batch.pos, - plot_results=True, + boundary_mask=batch.boundary_mask, + boundary_values=batch.boundary_values, + plot_results=False, ) loss = self._compute_loss(y_pred, y) + _plot_mesh(batch.pos, y, y_pred, batch.batch) self._log_loss(loss, batch, "test") return loss diff --git a/ThermalSolver/model/__init__.py b/ThermalSolver/model/__init__.py index cf8c596..2538982 100644 --- a/ThermalSolver/model/__init__.py +++ b/ThermalSolver/model/__init__.py @@ -1,5 +1,5 @@ __all__ = ["GraphFiniteDifference", "GatingGNO"] -from .finite_difference import GraphFiniteDifference +from .learnable_finite_difference import GraphFiniteDifference from .local_gno import GatingGNO from .point_net import PointNet diff --git a/ThermalSolver/model/basic_gno.py b/ThermalSolver/model/basic_gno.py deleted file mode 100644 index bb76f0b..0000000 --- a/ThermalSolver/model/basic_gno.py +++ /dev/null @@ -1,25 +0,0 @@ -from pina.model import GraphNeuralOperator -import torch -from torch_geometric.data import Data - - -class GNO(torch.nn.Module): - def __init__( - self, x_ch_node, f_ch_node, hidden, layers, edge_ch=0, out_ch=1 - ): - super().__init__() - - lifting_operator = torch.nn.Linear(x_ch_node + f_ch_node, hidden) - self.gno = GraphNeuralOperator( - lifting_operator=lifting_operator, - projection_operator=torch.nn.Linear(hidden, out_ch), - edge_features=edge_ch, - n_layers=layers, - internal_n_layers=2, - shared_weights=False, - ) - - def forward(self, x, c, edge_index, edge_attr): - x = torch.cat([x, c], dim=-1) - x = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) - return self.gno(x) diff --git a/ThermalSolver/model/finite_difference.py b/ThermalSolver/model/finite_difference.py index 28bee77..814eb84 100644 --- a/ThermalSolver/model/finite_difference.py +++ b/ThermalSolver/model/finite_difference.py @@ -9,32 +9,27 @@ class FiniteDifferenceStep(MessagePassing): TODO: add docstring. """ - def __init__( - self, - aggr: str = "add", - normalize: bool = True, - root_weight: float = 1.0, - ): + def __init__(self, aggr: str = "add", root_weight: float = 1.0): super().__init__(aggr=aggr) - - self.normalize = normalize assert ( aggr == "add" ), "Per somme pesate, l'aggregazione deve essere 'add'." self.root_weight = float(root_weight) - def forward(self, x, edge_index, edge_weight, deg): + def forward(self, x, edge_index, edge_attr, deg, weight=1.0): """ TODO: add docstring. """ - out = self.propagate(edge_index, x=x, edge_weight=edge_weight, deg=deg) + out = self.propagate( + edge_index, x=x, edge_attr=edge_attr, deg=deg, weight=weight + ) return out - def message(self, x_j, edge_weight): + def message(self, x_j, edge_attr): """ TODO: add docstring. """ - return edge_weight.view(-1, 1) * x_j + return edge_attr.view(-1, 1) * x_j def aggregate(self, inputs, index, deg): """ @@ -44,11 +39,12 @@ class FiniteDifferenceStep(MessagePassing): deg = deg + 1e-7 return out / deg.view(-1, 1) - def update(self, aggr_out, x): + def update(self, aggr_out, x, weight): """ TODO: add docstring. """ - return self.root_weight * aggr_out + (1 - self.root_weight) * x + print(weight) + return weight * aggr_out + (1 - weight) * x class GraphFiniteDifference(nn.Module): @@ -56,24 +52,22 @@ class GraphFiniteDifference(nn.Module): TODO: add docstring. """ - def __init__(self, max_iters: int = 1000, threshold: float = 1e-4): + def __init__(self, max_iters: int = 5000, threshold: float = 1e-4): """ TODO: add docstring. """ super().__init__() self.max_iters = max_iters self.threshold = threshold - self.fd_step = FiniteDifferenceStep( - aggr="add", normalize=True, root_weight=1.0 - ) + self.fd_step = FiniteDifferenceStep(aggr="add", root_weight=1.0) @staticmethod - def _compute_deg(edge_index, edge_weight, num_nodes): + def _compute_deg(edge_index, edge_attr, num_nodes): """ TODO: add docstring. """ deg = torch.zeros(num_nodes, device=edge_index.device) - deg = deg.scatter_add(0, edge_index[1], edge_weight) + deg = deg.scatter_add(0, edge_index[1], edge_attr) return deg + 1e-7 @staticmethod @@ -84,19 +78,29 @@ class GraphFiniteDifference(nn.Module): return (0.5 * (c[edge_index[0]] + c[edge_index[1]])).squeeze() def forward( - self, x, edge_index, edge_weight, c, boundary_mask, boundary_values + self, + x, + edge_index, + edge_attr, + c, + boundary_mask, + boundary_values, + **kwargs, ): """ TODO: add docstring. """ + edge_attr = 1 / edge_attr[:, -1] c_ij = self._compute_c_ij(c, edge_index) - edge_weight = edge_weight * c_ij - deg = self._compute_deg(edge_index, edge_weight, x.size(0)) + edge_attr = edge_attr * c_ij + deg = self._compute_deg(edge_index, edge_attr, x.size(0)) conv_thres = self.threshold * torch.norm(x) - for _i in tqdm(range(self.max_iters)): - out = self.fd_step(x, edge_index, edge_weight, deg) + weight = 1.0 + for _i in range(self.max_iters): + out = self.fd_step(x, edge_index, edge_attr, deg, weight=weight) + weight = weight * 0.9999 out[boundary_mask] = boundary_values.unsqueeze(-1) if torch.norm(out - x) < conv_thres: break x = out - return out + return out, _i + 1 diff --git a/ThermalSolver/model/point_net.py b/ThermalSolver/model/point_net.py index f0725e1..a4f4d01 100644 --- a/ThermalSolver/model/point_net.py +++ b/ThermalSolver/model/point_net.py @@ -108,14 +108,14 @@ class MLP(torch.nn.Module): tmp_layers.append(self._output_dim) self._layers = [] - self._LayerNorm = [] + self._batchnorm = [] for i in range(len(tmp_layers) - 1): self._layers.append( self.spect_norm(nn.Linear(tmp_layers[i], tmp_layers[i + 1])) ) - self._LayerNorm.append(nn.LazyLayerNorm()) + self._batchnorm.append(nn.LazyBatchNorm1d()) if isinstance(func, list): self._functions = func @@ -124,7 +124,7 @@ class MLP(torch.nn.Module): unique_list = [] for layer, func, bnorm in zip( - self._layers[:-1], self._functions, self._LayerNorm + self._layers[:-1], self._functions, self._batchnorm ): unique_list.append(layer) @@ -208,7 +208,7 @@ class TNet(nn.Module): ) self._function = function() - self._bn1 = nn.LazyLayerNorm() + self._bn1 = nn.LazyBatchNorm1d() def forward(self, X): """Forward pass for T-Net @@ -299,9 +299,9 @@ class PointNet(nn.Module): self._tnet_feature = TNet(input_dim=64) self._function = function() - self._bn1 = nn.LazyLayerNorm() - self._bn2 = nn.LazyLayerNorm() - self._bn3 = nn.LazyLayerNorm() + self._bn1 = nn.LazyBatchNorm1d() + self._bn2 = nn.LazyBatchNorm1d() + self._bn3 = nn.LazyBatchNorm1d() def concat(self, embedding, input_): """Returns concatenation of global and local features for Point-Net @@ -370,3 +370,205 @@ class PointNet(nn.Module): X = self._mlp4(X) return X + + +class ConvTNet(nn.Module): + """T-Net base class. Implementation of T-Network with convolutional layers. + + Reference: Ali Kashefi et al. https://arxiv.org/abs/2208.13434 + """ + + def __init__(self, input_dim): + """T-Net block constructor + + :param input_dim: input dimension of point cloud + :type input_dim: int + """ + super().__init__() + + function = nn.Tanh + self._function = function() + + self._block1 = nn.Sequential( + nn.Conv1d(input_dim, 64, 1), + nn.BatchNorm1d(64), + self._function, + nn.Conv1d(64, 128, 1), + nn.BatchNorm1d(128), + self._function, + nn.Conv1d(128, 1024, 1), + nn.BatchNorm1d(1024), + self._function, + ) + + self._block2 = MLP( + input_dim=1024, + output_dim=input_dim * input_dim, + layers=[512, 256], + func=function, + batch_norm=True, + ) + + def forward(self, X): + """Forward pass for T-Net + + :param X: input tensor, shape [batch, $input_{dim}$, N] + with batch the batch size, N number of points and $input_{dim}$ + the input dimension of the point cloud. + :type X: torch.tensor + :return: output affine matrix transformation, shape + [batch, $input_{dim} \times input_{dim}$] with batch + the batch size and $input_{dim}$ the input dimension + of the point cloud. + :rtype: torch.tensor + """ + + batch, input_dim = X.shape[0], X.shape[1] + + # encoding using first MLP + X = self._block1(X) + + # applying symmetric function to aggregate information (using max as default) + X, _ = torch.max(X, dim=-1) + + # decoding using third MLP + X = self._block2(X) + + return X.reshape(batch, input_dim, input_dim) + + +class ConvPointNet(nn.Module): + """Point-Net base class. Implementation of Point Network for segmentation. + + Reference: Ali Kashefi et al. https://arxiv.org/abs/2208.13434 + """ + + def __init__(self, input_dim, output_dim, tnet=False): + """Point-Net block constructor + + :param input_dim: input dimension of point cloud + :type input_dim: int + :param output_dim: output dimension of point cloud + :type output_dim: int + :param tnet: apply T-Net transformation, defaults to False + :type tnet: bool, optional + """ + super().__init__() + + self._function = nn.Tanh() + self._use_tnet = tnet + + self._block1 = nn.Sequential( + nn.Conv1d(input_dim, 64, 1), + nn.BatchNorm1d(), + self._function, + nn.Conv1d(64, 64, 1), + nn.BatchNorm1d(64), + self._function, + ) + + self._block2 = nn.Sequential( + nn.Conv1d(64, 64, 1), + nn.BatchNorm1d(64), + self._function, + nn.Conv1d(64, 128, 1), + nn.BatchNorm1d(128), + self._function, + nn.Conv1d(128, 1024, 1), + nn.BatchNorm1d(1024), + self._function, + ) + + self._block3 = nn.Sequential( + nn.Conv1d(1088, 512, 1), + nn.BatchNorm1d(512), + self._function, + nn.Conv1d(512, 256, 1), + nn.BatchNorm1d(256), + self._function, + nn.Conv1d(256, 128, 1), + nn.BatchNorm1d(128), + self._function, + ) + + self._block4 = nn.Conv1d(128, output_dim, 1) + + if self._use_tnet: + self._tnet_transform = ConvTNet(input_dim=input_dim) + self._tnet_feature = ConvTNet(input_dim=64) + + def concat(self, embedding, input_): + """ + Returns concatenation of global and local features for Point-Net + + :param embedding: global features of Point-Net, shape [batch, $input_{dim}$] + with batch the batch size and $input_{dim}$ the input dimension + of the point cloud. + :type embedding: torch.tensor + :param input_: local features of Point-Net, shape [batch, N, $input_{dim}$] + with batch the batch size, N number of points and $input_{dim}$ + the input dimension of the point cloud. + :type input_: torch.tensor + :return: concatenation vector, shape [batch, N, $input_{dim}$] + with batch the batch size, N number of points and $input_{dim}$ + :rtype: torch.tensor + """ + n_points = input_.shape[-1] + embedding = embedding.unsqueeze(2).repeat(1, 1, n_points) + return torch.cat([embedding, input_], dim=1) + + def forward(self, X): + """Forward pass for Point-Net + + :param X: input tensor, shape [batch, N, $input_{dim}$] + with batch the batch size, N number of points and $input_{dim}$ + the input dimension of the point cloud. + :type X: torch.tensor + :return: segmentation vector, shape [batch, N, $output_{dim}$] + with batch the batch size, N number of points and $output_{dim}$ + the output dimension of the point cloud. + :rtype: torch.tensor + """ + + # permuting indeces + X = X.permute(0, 2, 1) + + # using transform tnet if needed + if self._use_tnet: + transform = self._tnet_transform(X) + X = X.transpose(2, 1) + X = torch.matmul(X, transform) + X = X.transpose(2, 1) + + # encoding using first MLP + X = self._block1(X) + + # using transform tnet if needed + if self._use_tnet: + transform = self._tnet_feature(X) + X = X.transpose(2, 1) + X = torch.matmul(X, transform) + X = X.transpose(2, 1) + + # saving latent representation for later concatanation + latent = X + + # encoding using second MLP + X = self._block2(X) + + # applying symmetric function to aggregate information (using max as default) + X, _ = torch.max(X, dim=-1) + + # concatenating with latent vector + X = self.concat(X, latent) + + # decoding using third MLP + X = self._block3(X) + + # decoding using fourth MLP + X = self._block4(X) + + # permuting indeces + X = X.permute(0, 2, 1) + + return X diff --git a/ThermalSolver/point_module.py b/ThermalSolver/point_module.py index a3bf592..c3eae71 100644 --- a/ThermalSolver/point_module.py +++ b/ThermalSolver/point_module.py @@ -15,15 +15,20 @@ def _plot_mesh(x, y, y_pred): y_pred = y_pred[x[:, 0] != -1] tria = Triangulation(pos[:, 2], pos[:, 3]) - plt.figure(figsize=(12, 5)) - plt.subplot(1, 2, 1) + plt.figure(figsize=(18, 5)) + plt.subplot(1, 3, 1) plt.tricontourf(tria, y.squeeze().numpy(), levels=14) plt.colorbar() plt.title("True temperature") - plt.subplot(1, 2, 2) + plt.subplot(1, 3, 2) plt.tricontourf(tria, y_pred.squeeze().numpy(), levels=14) plt.colorbar() plt.title("Predicted temperature") + plt.subplot(1, 3, 3) + plt.tricontourf(tria, torch.abs(y_pred - y).squeeze().numpy(), levels=14) + plt.colorbar() + plt.title("Error") + plt.suptitle("PointNet", fontsize=16) plt.savefig("point_net.png", dpi=300)