add refined geometries in datamodule
This commit is contained in:
@@ -35,53 +35,36 @@ class GraphDataModule(LightningDataModule):
|
||||
self.geometry = load_dataset(self.hf_repo, name="geometry")[
|
||||
self.split_name
|
||||
]
|
||||
edge_index = torch.tensor(
|
||||
self.geometry["edge_index"][0], dtype=torch.int64
|
||||
)
|
||||
pos = torch.tensor(self.geometry["points"][0], dtype=torch.float32)[
|
||||
:, :2
|
||||
]
|
||||
|
||||
bottom_ids = torch.tensor(
|
||||
self.geometry["bottom_boundary_ids"][0], dtype=torch.long
|
||||
)
|
||||
top_ids = torch.tensor(
|
||||
self.geometry["top_boundary_ids"][0], dtype=torch.long
|
||||
)
|
||||
left_ids = torch.tensor(
|
||||
self.geometry["left_boundary_ids"][0], dtype=torch.long
|
||||
)
|
||||
right_ids = torch.tensor(
|
||||
self.geometry["right_boundary_ids"][0], dtype=torch.long
|
||||
)
|
||||
self.data = [
|
||||
self._build_dataset(
|
||||
snapshot,
|
||||
edge_index.T,
|
||||
pos,
|
||||
bottom_ids,
|
||||
top_ids,
|
||||
left_ids,
|
||||
right_ids,
|
||||
self._build_dataset(snapshot, geometry)
|
||||
for snapshot, geometry in tqdm(
|
||||
zip(hf_dataset, self.geometry),
|
||||
desc="Building graphs",
|
||||
total=len(hf_dataset),
|
||||
)
|
||||
for snapshot in tqdm(hf_dataset, desc="Building graphs")
|
||||
]
|
||||
|
||||
def _build_dataset(
|
||||
self,
|
||||
snapshot: dict,
|
||||
edge_index: torch.Tensor,
|
||||
pos: torch.Tensor,
|
||||
bottom_ids: torch.Tensor,
|
||||
top_ids: torch.Tensor,
|
||||
left_ids: torch.Tensor,
|
||||
right_ids: torch.Tensor,
|
||||
geometry: dict,
|
||||
) -> Data:
|
||||
conductivity = torch.tensor(
|
||||
snapshot["conductivity"], dtype=torch.float32
|
||||
)
|
||||
temperature = torch.tensor(snapshot["temperature"], dtype=torch.float32)
|
||||
|
||||
edge_index = torch.tensor(geometry["edge_index"], dtype=torch.int64).T
|
||||
pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2]
|
||||
bottom_ids = torch.tensor(
|
||||
geometry["bottom_boundary_ids"], dtype=torch.long
|
||||
)
|
||||
top_ids = torch.tensor(geometry["top_boundary_ids"], dtype=torch.long)
|
||||
left_ids = torch.tensor(geometry["left_boundary_ids"], dtype=torch.long)
|
||||
right_ids = torch.tensor(
|
||||
geometry["right_boundary_ids"], dtype=torch.long
|
||||
)
|
||||
|
||||
edge_index = to_undirected(edge_index, num_nodes=pos.size(0))
|
||||
edge_attr = pos[edge_index[0]] - pos[edge_index[1]]
|
||||
edge_attr = torch.cat(
|
||||
|
||||
Reference in New Issue
Block a user