add refined geometries in datamodule

This commit is contained in:
Filippo Olivo
2025-10-02 12:20:22 +02:00
parent 1be98591d4
commit b26403189f
2 changed files with 23 additions and 37 deletions

View File

@@ -35,53 +35,36 @@ class GraphDataModule(LightningDataModule):
self.geometry = load_dataset(self.hf_repo, name="geometry")[
self.split_name
]
edge_index = torch.tensor(
self.geometry["edge_index"][0], dtype=torch.int64
)
pos = torch.tensor(self.geometry["points"][0], dtype=torch.float32)[
:, :2
]
bottom_ids = torch.tensor(
self.geometry["bottom_boundary_ids"][0], dtype=torch.long
)
top_ids = torch.tensor(
self.geometry["top_boundary_ids"][0], dtype=torch.long
)
left_ids = torch.tensor(
self.geometry["left_boundary_ids"][0], dtype=torch.long
)
right_ids = torch.tensor(
self.geometry["right_boundary_ids"][0], dtype=torch.long
)
self.data = [
self._build_dataset(
snapshot,
edge_index.T,
pos,
bottom_ids,
top_ids,
left_ids,
right_ids,
self._build_dataset(snapshot, geometry)
for snapshot, geometry in tqdm(
zip(hf_dataset, self.geometry),
desc="Building graphs",
total=len(hf_dataset),
)
for snapshot in tqdm(hf_dataset, desc="Building graphs")
]
def _build_dataset(
self,
snapshot: dict,
edge_index: torch.Tensor,
pos: torch.Tensor,
bottom_ids: torch.Tensor,
top_ids: torch.Tensor,
left_ids: torch.Tensor,
right_ids: torch.Tensor,
geometry: dict,
) -> Data:
conductivity = torch.tensor(
snapshot["conductivity"], dtype=torch.float32
)
temperature = torch.tensor(snapshot["temperature"], dtype=torch.float32)
edge_index = torch.tensor(geometry["edge_index"], dtype=torch.int64).T
pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2]
bottom_ids = torch.tensor(
geometry["bottom_boundary_ids"], dtype=torch.long
)
top_ids = torch.tensor(geometry["top_boundary_ids"], dtype=torch.long)
left_ids = torch.tensor(geometry["left_boundary_ids"], dtype=torch.long)
right_ids = torch.tensor(
geometry["right_boundary_ids"], dtype=torch.long
)
edge_index = to_undirected(edge_index, num_nodes=pos.size(0))
edge_attr = pos[edge_index[0]] - pos[edge_index[1]]
edge_attr = torch.cat(