From b26403189fb7cd7ada09fd2aadbe0aef0bc689c4 Mon Sep 17 00:00:00 2001 From: Filippo Olivo Date: Thu, 2 Oct 2025 12:20:22 +0200 Subject: [PATCH] add refined geometries in datamodule --- ThermalSolver/data_module.py | 51 +++++++++++--------------------- ThermalSolver/model/local_gno.py | 9 ++++-- 2 files changed, 23 insertions(+), 37 deletions(-) diff --git a/ThermalSolver/data_module.py b/ThermalSolver/data_module.py index 127fcea..a65ed22 100644 --- a/ThermalSolver/data_module.py +++ b/ThermalSolver/data_module.py @@ -35,53 +35,36 @@ class GraphDataModule(LightningDataModule): self.geometry = load_dataset(self.hf_repo, name="geometry")[ self.split_name ] - edge_index = torch.tensor( - self.geometry["edge_index"][0], dtype=torch.int64 - ) - pos = torch.tensor(self.geometry["points"][0], dtype=torch.float32)[ - :, :2 - ] - - bottom_ids = torch.tensor( - self.geometry["bottom_boundary_ids"][0], dtype=torch.long - ) - top_ids = torch.tensor( - self.geometry["top_boundary_ids"][0], dtype=torch.long - ) - left_ids = torch.tensor( - self.geometry["left_boundary_ids"][0], dtype=torch.long - ) - right_ids = torch.tensor( - self.geometry["right_boundary_ids"][0], dtype=torch.long - ) self.data = [ - self._build_dataset( - snapshot, - edge_index.T, - pos, - bottom_ids, - top_ids, - left_ids, - right_ids, + self._build_dataset(snapshot, geometry) + for snapshot, geometry in tqdm( + zip(hf_dataset, self.geometry), + desc="Building graphs", + total=len(hf_dataset), ) - for snapshot in tqdm(hf_dataset, desc="Building graphs") ] def _build_dataset( self, snapshot: dict, - edge_index: torch.Tensor, - pos: torch.Tensor, - bottom_ids: torch.Tensor, - top_ids: torch.Tensor, - left_ids: torch.Tensor, - right_ids: torch.Tensor, + geometry: dict, ) -> Data: conductivity = torch.tensor( snapshot["conductivity"], dtype=torch.float32 ) temperature = torch.tensor(snapshot["temperature"], dtype=torch.float32) + edge_index = torch.tensor(geometry["edge_index"], dtype=torch.int64).T + pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2] + bottom_ids = torch.tensor( + geometry["bottom_boundary_ids"], dtype=torch.long + ) + top_ids = torch.tensor(geometry["top_boundary_ids"], dtype=torch.long) + left_ids = torch.tensor(geometry["left_boundary_ids"], dtype=torch.long) + right_ids = torch.tensor( + geometry["right_boundary_ids"], dtype=torch.long + ) + edge_index = to_undirected(edge_index, num_nodes=pos.size(0)) edge_attr = pos[edge_index[0]] - pos[edge_index[1]] edge_attr = torch.cat( diff --git a/ThermalSolver/model/local_gno.py b/ThermalSolver/model/local_gno.py index 6432850..7177038 100644 --- a/ThermalSolver/model/local_gno.py +++ b/ThermalSolver/model/local_gno.py @@ -106,10 +106,13 @@ class ConditionalGNOBlock(MessagePassing): def message(self, x_i, x_j, c_i, c_j, edge_attr): c_ij = 0.5 * (c_i + c_j) alpha = torch.sigmoid(self.balancing) - m = alpha * self.diff_net(x_j - x_i) + (1 - alpha) * self.x_net(x_j) + gate = torch.sigmoid(self.edge_attr_net(edge_attr)) + m = ( + alpha * self.diff_net(x_j - x_i) + + (1 - alpha) * self.x_net(x_j) * gate + ) m = m * self.c_ij_net(c_ij) - gate = self.edge_attr_net(edge_attr) - return m * torch.sigmoid(gate) + return m def update(self, aggr_out, x): return x + self.alpha * self.msg_proj(aggr_out)