minor
This commit is contained in:
@@ -40,6 +40,9 @@ class GraphDataModule(LightningDataModule):
|
||||
pos = torch.tensor(self.geometry["points"][0], dtype=torch.float32)[
|
||||
:, :2
|
||||
]
|
||||
bottom_boundary_ids = torch.tensor(
|
||||
self.geometry["bottom_boundary_ids"][0], dtype=torch.int64
|
||||
)
|
||||
self.data = [
|
||||
self._build_dataset(
|
||||
torch.tensor(snapshot["conductivity"], dtype=torch.float32),
|
||||
@@ -47,6 +50,7 @@ class GraphDataModule(LightningDataModule):
|
||||
torch.tensor(snapshot["temperature"], dtype=torch.float32),
|
||||
edge_index.T,
|
||||
pos,
|
||||
bottom_boundary_ids,
|
||||
)
|
||||
for snapshot in tqdm(hf_dataset, desc="Building graphs")
|
||||
]
|
||||
@@ -58,13 +62,15 @@ class GraphDataModule(LightningDataModule):
|
||||
temperature: torch.Tensor,
|
||||
edge_index: torch.Tensor,
|
||||
pos: torch.Tensor,
|
||||
bottom_boundary_ids: torch.Tensor,
|
||||
) -> Data:
|
||||
edge_index = to_undirected(edge_index, num_nodes=pos.size(0))
|
||||
edge_attr = pos[edge_index[0]] - pos[edge_index[1]]
|
||||
edge_attr = torch.cat(
|
||||
[edge_attr, torch.norm(edge_attr, dim=1).unsqueeze(-1)], dim=1
|
||||
)
|
||||
|
||||
boundary_temperature = boundary_vales[bottom_boundary_ids].max()
|
||||
boundary_vales[bottom_boundary_ids] = 1.0
|
||||
return Data(
|
||||
x=boundary_vales.unsqueeze(-1),
|
||||
c=conductivity.unsqueeze(-1),
|
||||
@@ -72,6 +78,7 @@ class GraphDataModule(LightningDataModule):
|
||||
pos=pos,
|
||||
edge_attr=edge_attr,
|
||||
y=temperature.unsqueeze(-1),
|
||||
boundary_temperature=boundary_vales[bottom_boundary_ids].max(),
|
||||
)
|
||||
|
||||
def setup(self, stage: str = None):
|
||||
|
||||
Reference in New Issue
Block a user