From 195c66b444c89f58903bee8e631c733f6d41bc5a Mon Sep 17 00:00:00 2001 From: Filippo Olivo Date: Fri, 7 Nov 2025 15:52:34 +0100 Subject: [PATCH] add radius graph option --- ThermalSolver/graph_datamodule.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/ThermalSolver/graph_datamodule.py b/ThermalSolver/graph_datamodule.py index a3c1145..9920aa5 100644 --- a/ThermalSolver/graph_datamodule.py +++ b/ThermalSolver/graph_datamodule.py @@ -18,6 +18,8 @@ class GraphDataModule(LightningDataModule): test_size: float = 0.1, batch_size: int = 32, remove_boundary_edges: bool = False, + build_radial_graph: bool = False, + radius: float = None, ): super().__init__() self.hf_repo = hf_repo @@ -29,6 +31,8 @@ class GraphDataModule(LightningDataModule): self.test_size = test_size self.batch_size = batch_size self.remove_boundary_edges = remove_boundary_edges + self.build_radial_graph = build_radial_graph + self.radius = radius def prepare_data(self): dataset = load_dataset(self.hf_repo, name="snapshots")[self.split_name] @@ -80,9 +84,8 @@ class GraphDataModule(LightningDataModule): ) temperature = torch.tensor(snapshot["temperature"], dtype=torch.float32) - edge_index = torch.tensor(geometry["edge_index"], dtype=torch.int64).T - pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2] + bottom_ids = torch.tensor( geometry["bottom_boundary_ids"], dtype=torch.long ) @@ -92,11 +95,28 @@ class GraphDataModule(LightningDataModule): geometry["right_boundary_ids"], dtype=torch.long ) - edge_index = to_undirected(edge_index, num_nodes=pos.size(0)) + if self.build_radial_graph: + from pina.graph import RadiusGraph + + if self.radius is None: + raise ValueError("Radius must be specified for radial graph.") + edge_index = RadiusGraph.compute_radius_graph( + pos, radius=self.radius + ) + from torch_geometric.utils import remove_self_loops + + edge_index, _ = remove_self_loops(edge_index) + else: + + edge_index = torch.tensor( + geometry["edge_index"], dtype=torch.int64 + ).T + edge_index = to_undirected(edge_index, num_nodes=pos.size(0)) boundary_mask, boundary_values = self._compute_boundary_mask( bottom_ids, right_ids, top_ids, left_ids, temperature ) + if self.remove_boundary_edges: boundary_idx = torch.unique(boundary_mask) edge_index_mask = ~torch.isin(edge_index[1], boundary_idx)