add refined geometries in datamodule
This commit is contained in:
@@ -35,53 +35,36 @@ class GraphDataModule(LightningDataModule):
|
|||||||
self.geometry = load_dataset(self.hf_repo, name="geometry")[
|
self.geometry = load_dataset(self.hf_repo, name="geometry")[
|
||||||
self.split_name
|
self.split_name
|
||||||
]
|
]
|
||||||
edge_index = torch.tensor(
|
|
||||||
self.geometry["edge_index"][0], dtype=torch.int64
|
|
||||||
)
|
|
||||||
pos = torch.tensor(self.geometry["points"][0], dtype=torch.float32)[
|
|
||||||
:, :2
|
|
||||||
]
|
|
||||||
|
|
||||||
bottom_ids = torch.tensor(
|
|
||||||
self.geometry["bottom_boundary_ids"][0], dtype=torch.long
|
|
||||||
)
|
|
||||||
top_ids = torch.tensor(
|
|
||||||
self.geometry["top_boundary_ids"][0], dtype=torch.long
|
|
||||||
)
|
|
||||||
left_ids = torch.tensor(
|
|
||||||
self.geometry["left_boundary_ids"][0], dtype=torch.long
|
|
||||||
)
|
|
||||||
right_ids = torch.tensor(
|
|
||||||
self.geometry["right_boundary_ids"][0], dtype=torch.long
|
|
||||||
)
|
|
||||||
self.data = [
|
self.data = [
|
||||||
self._build_dataset(
|
self._build_dataset(snapshot, geometry)
|
||||||
snapshot,
|
for snapshot, geometry in tqdm(
|
||||||
edge_index.T,
|
zip(hf_dataset, self.geometry),
|
||||||
pos,
|
desc="Building graphs",
|
||||||
bottom_ids,
|
total=len(hf_dataset),
|
||||||
top_ids,
|
|
||||||
left_ids,
|
|
||||||
right_ids,
|
|
||||||
)
|
)
|
||||||
for snapshot in tqdm(hf_dataset, desc="Building graphs")
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def _build_dataset(
|
def _build_dataset(
|
||||||
self,
|
self,
|
||||||
snapshot: dict,
|
snapshot: dict,
|
||||||
edge_index: torch.Tensor,
|
geometry: dict,
|
||||||
pos: torch.Tensor,
|
|
||||||
bottom_ids: torch.Tensor,
|
|
||||||
top_ids: torch.Tensor,
|
|
||||||
left_ids: torch.Tensor,
|
|
||||||
right_ids: torch.Tensor,
|
|
||||||
) -> Data:
|
) -> Data:
|
||||||
conductivity = torch.tensor(
|
conductivity = torch.tensor(
|
||||||
snapshot["conductivity"], dtype=torch.float32
|
snapshot["conductivity"], dtype=torch.float32
|
||||||
)
|
)
|
||||||
temperature = torch.tensor(snapshot["temperature"], dtype=torch.float32)
|
temperature = torch.tensor(snapshot["temperature"], dtype=torch.float32)
|
||||||
|
|
||||||
|
edge_index = torch.tensor(geometry["edge_index"], dtype=torch.int64).T
|
||||||
|
pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2]
|
||||||
|
bottom_ids = torch.tensor(
|
||||||
|
geometry["bottom_boundary_ids"], dtype=torch.long
|
||||||
|
)
|
||||||
|
top_ids = torch.tensor(geometry["top_boundary_ids"], dtype=torch.long)
|
||||||
|
left_ids = torch.tensor(geometry["left_boundary_ids"], dtype=torch.long)
|
||||||
|
right_ids = torch.tensor(
|
||||||
|
geometry["right_boundary_ids"], dtype=torch.long
|
||||||
|
)
|
||||||
|
|
||||||
edge_index = to_undirected(edge_index, num_nodes=pos.size(0))
|
edge_index = to_undirected(edge_index, num_nodes=pos.size(0))
|
||||||
edge_attr = pos[edge_index[0]] - pos[edge_index[1]]
|
edge_attr = pos[edge_index[0]] - pos[edge_index[1]]
|
||||||
edge_attr = torch.cat(
|
edge_attr = torch.cat(
|
||||||
|
|||||||
@@ -106,10 +106,13 @@ class ConditionalGNOBlock(MessagePassing):
|
|||||||
def message(self, x_i, x_j, c_i, c_j, edge_attr):
|
def message(self, x_i, x_j, c_i, c_j, edge_attr):
|
||||||
c_ij = 0.5 * (c_i + c_j)
|
c_ij = 0.5 * (c_i + c_j)
|
||||||
alpha = torch.sigmoid(self.balancing)
|
alpha = torch.sigmoid(self.balancing)
|
||||||
m = alpha * self.diff_net(x_j - x_i) + (1 - alpha) * self.x_net(x_j)
|
gate = torch.sigmoid(self.edge_attr_net(edge_attr))
|
||||||
|
m = (
|
||||||
|
alpha * self.diff_net(x_j - x_i)
|
||||||
|
+ (1 - alpha) * self.x_net(x_j) * gate
|
||||||
|
)
|
||||||
m = m * self.c_ij_net(c_ij)
|
m = m * self.c_ij_net(c_ij)
|
||||||
gate = self.edge_attr_net(edge_attr)
|
return m
|
||||||
return m * torch.sigmoid(gate)
|
|
||||||
|
|
||||||
def update(self, aggr_out, x):
|
def update(self, aggr_out, x):
|
||||||
return x + self.alpha * self.msg_proj(aggr_out)
|
return x + self.alpha * self.msg_proj(aggr_out)
|
||||||
|
|||||||
Reference in New Issue
Block a user