add radius graph option

This commit is contained in:
Filippo Olivo
2025-11-07 15:52:34 +01:00
parent 5c5483744c
commit 195c66b444

View File

@@ -18,6 +18,8 @@ class GraphDataModule(LightningDataModule):
test_size: float = 0.1,
batch_size: int = 32,
remove_boundary_edges: bool = False,
build_radial_graph: bool = False,
radius: float = None,
):
super().__init__()
self.hf_repo = hf_repo
@@ -29,6 +31,8 @@ class GraphDataModule(LightningDataModule):
self.test_size = test_size
self.batch_size = batch_size
self.remove_boundary_edges = remove_boundary_edges
self.build_radial_graph = build_radial_graph
self.radius = radius
def prepare_data(self):
dataset = load_dataset(self.hf_repo, name="snapshots")[self.split_name]
@@ -80,9 +84,8 @@ class GraphDataModule(LightningDataModule):
)
temperature = torch.tensor(snapshot["temperature"], dtype=torch.float32)
edge_index = torch.tensor(geometry["edge_index"], dtype=torch.int64).T
pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2]
bottom_ids = torch.tensor(
geometry["bottom_boundary_ids"], dtype=torch.long
)
@@ -92,11 +95,28 @@ class GraphDataModule(LightningDataModule):
geometry["right_boundary_ids"], dtype=torch.long
)
edge_index = to_undirected(edge_index, num_nodes=pos.size(0))
if self.build_radial_graph:
from pina.graph import RadiusGraph
if self.radius is None:
raise ValueError("Radius must be specified for radial graph.")
edge_index = RadiusGraph.compute_radius_graph(
pos, radius=self.radius
)
from torch_geometric.utils import remove_self_loops
edge_index, _ = remove_self_loops(edge_index)
else:
edge_index = torch.tensor(
geometry["edge_index"], dtype=torch.int64
).T
edge_index = to_undirected(edge_index, num_nodes=pos.size(0))
boundary_mask, boundary_values = self._compute_boundary_mask(
bottom_ids, right_ids, top_ids, left_ids, temperature
)
if self.remove_boundary_edges:
boundary_idx = torch.unique(boundary_mask)
edge_index_mask = ~torch.isin(edge_index[1], boundary_idx)