add radius graph option
This commit is contained in:
@@ -18,6 +18,8 @@ class GraphDataModule(LightningDataModule):
|
|||||||
test_size: float = 0.1,
|
test_size: float = 0.1,
|
||||||
batch_size: int = 32,
|
batch_size: int = 32,
|
||||||
remove_boundary_edges: bool = False,
|
remove_boundary_edges: bool = False,
|
||||||
|
build_radial_graph: bool = False,
|
||||||
|
radius: float = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hf_repo = hf_repo
|
self.hf_repo = hf_repo
|
||||||
@@ -29,6 +31,8 @@ class GraphDataModule(LightningDataModule):
|
|||||||
self.test_size = test_size
|
self.test_size = test_size
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.remove_boundary_edges = remove_boundary_edges
|
self.remove_boundary_edges = remove_boundary_edges
|
||||||
|
self.build_radial_graph = build_radial_graph
|
||||||
|
self.radius = radius
|
||||||
|
|
||||||
def prepare_data(self):
|
def prepare_data(self):
|
||||||
dataset = load_dataset(self.hf_repo, name="snapshots")[self.split_name]
|
dataset = load_dataset(self.hf_repo, name="snapshots")[self.split_name]
|
||||||
@@ -80,9 +84,8 @@ class GraphDataModule(LightningDataModule):
|
|||||||
)
|
)
|
||||||
temperature = torch.tensor(snapshot["temperature"], dtype=torch.float32)
|
temperature = torch.tensor(snapshot["temperature"], dtype=torch.float32)
|
||||||
|
|
||||||
edge_index = torch.tensor(geometry["edge_index"], dtype=torch.int64).T
|
|
||||||
|
|
||||||
pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2]
|
pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2]
|
||||||
|
|
||||||
bottom_ids = torch.tensor(
|
bottom_ids = torch.tensor(
|
||||||
geometry["bottom_boundary_ids"], dtype=torch.long
|
geometry["bottom_boundary_ids"], dtype=torch.long
|
||||||
)
|
)
|
||||||
@@ -92,11 +95,28 @@ class GraphDataModule(LightningDataModule):
|
|||||||
geometry["right_boundary_ids"], dtype=torch.long
|
geometry["right_boundary_ids"], dtype=torch.long
|
||||||
)
|
)
|
||||||
|
|
||||||
edge_index = to_undirected(edge_index, num_nodes=pos.size(0))
|
if self.build_radial_graph:
|
||||||
|
from pina.graph import RadiusGraph
|
||||||
|
|
||||||
|
if self.radius is None:
|
||||||
|
raise ValueError("Radius must be specified for radial graph.")
|
||||||
|
edge_index = RadiusGraph.compute_radius_graph(
|
||||||
|
pos, radius=self.radius
|
||||||
|
)
|
||||||
|
from torch_geometric.utils import remove_self_loops
|
||||||
|
|
||||||
|
edge_index, _ = remove_self_loops(edge_index)
|
||||||
|
else:
|
||||||
|
|
||||||
|
edge_index = torch.tensor(
|
||||||
|
geometry["edge_index"], dtype=torch.int64
|
||||||
|
).T
|
||||||
|
edge_index = to_undirected(edge_index, num_nodes=pos.size(0))
|
||||||
|
|
||||||
boundary_mask, boundary_values = self._compute_boundary_mask(
|
boundary_mask, boundary_values = self._compute_boundary_mask(
|
||||||
bottom_ids, right_ids, top_ids, left_ids, temperature
|
bottom_ids, right_ids, top_ids, left_ids, temperature
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.remove_boundary_edges:
|
if self.remove_boundary_edges:
|
||||||
boundary_idx = torch.unique(boundary_mask)
|
boundary_idx = torch.unique(boundary_mask)
|
||||||
edge_index_mask = ~torch.isin(edge_index[1], boundary_idx)
|
edge_index_mask = ~torch.isin(edge_index[1], boundary_idx)
|
||||||
|
|||||||
Reference in New Issue
Block a user