add radius graph option
This commit is contained in:
@@ -18,6 +18,8 @@ class GraphDataModule(LightningDataModule):
|
||||
test_size: float = 0.1,
|
||||
batch_size: int = 32,
|
||||
remove_boundary_edges: bool = False,
|
||||
build_radial_graph: bool = False,
|
||||
radius: float = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hf_repo = hf_repo
|
||||
@@ -29,6 +31,8 @@ class GraphDataModule(LightningDataModule):
|
||||
self.test_size = test_size
|
||||
self.batch_size = batch_size
|
||||
self.remove_boundary_edges = remove_boundary_edges
|
||||
self.build_radial_graph = build_radial_graph
|
||||
self.radius = radius
|
||||
|
||||
def prepare_data(self):
|
||||
dataset = load_dataset(self.hf_repo, name="snapshots")[self.split_name]
|
||||
@@ -80,9 +84,8 @@ class GraphDataModule(LightningDataModule):
|
||||
)
|
||||
temperature = torch.tensor(snapshot["temperature"], dtype=torch.float32)
|
||||
|
||||
edge_index = torch.tensor(geometry["edge_index"], dtype=torch.int64).T
|
||||
|
||||
pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2]
|
||||
|
||||
bottom_ids = torch.tensor(
|
||||
geometry["bottom_boundary_ids"], dtype=torch.long
|
||||
)
|
||||
@@ -92,11 +95,28 @@ class GraphDataModule(LightningDataModule):
|
||||
geometry["right_boundary_ids"], dtype=torch.long
|
||||
)
|
||||
|
||||
if self.build_radial_graph:
|
||||
from pina.graph import RadiusGraph
|
||||
|
||||
if self.radius is None:
|
||||
raise ValueError("Radius must be specified for radial graph.")
|
||||
edge_index = RadiusGraph.compute_radius_graph(
|
||||
pos, radius=self.radius
|
||||
)
|
||||
from torch_geometric.utils import remove_self_loops
|
||||
|
||||
edge_index, _ = remove_self_loops(edge_index)
|
||||
else:
|
||||
|
||||
edge_index = torch.tensor(
|
||||
geometry["edge_index"], dtype=torch.int64
|
||||
).T
|
||||
edge_index = to_undirected(edge_index, num_nodes=pos.size(0))
|
||||
|
||||
boundary_mask, boundary_values = self._compute_boundary_mask(
|
||||
bottom_ids, right_ids, top_ids, left_ids, temperature
|
||||
)
|
||||
|
||||
if self.remove_boundary_edges:
|
||||
boundary_idx = torch.unique(boundary_mask)
|
||||
edge_index_mask = ~torch.isin(edge_index[1], boundary_idx)
|
||||
|
||||
Reference in New Issue
Block a user