random changes
This commit is contained in:
@@ -6,7 +6,6 @@ from torch_geometric.data import Data
|
||||
from torch_geometric.loader import DataLoader
|
||||
from torch_geometric.utils import to_undirected
|
||||
from .mesh_data import MeshData
|
||||
import os
|
||||
|
||||
|
||||
class GraphDataModule(LightningDataModule):
|
||||
@@ -18,7 +17,7 @@ class GraphDataModule(LightningDataModule):
|
||||
val_size: float = 0.1,
|
||||
test_size: float = 0.1,
|
||||
batch_size: int = 32,
|
||||
remove_boundary_edges: bool = True,
|
||||
remove_boundary_edges: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.hf_repo = hf_repo
|
||||
@@ -82,6 +81,7 @@ class GraphDataModule(LightningDataModule):
|
||||
temperature = torch.tensor(snapshot["temperature"], dtype=torch.float32)
|
||||
|
||||
edge_index = torch.tensor(geometry["edge_index"], dtype=torch.int64).T
|
||||
|
||||
pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2]
|
||||
bottom_ids = torch.tensor(
|
||||
geometry["bottom_boundary_ids"], dtype=torch.long
|
||||
@@ -97,7 +97,6 @@ class GraphDataModule(LightningDataModule):
|
||||
boundary_mask, boundary_values = self._compute_boundary_mask(
|
||||
bottom_ids, right_ids, top_ids, left_ids, temperature
|
||||
)
|
||||
|
||||
if self.remove_boundary_edges:
|
||||
boundary_idx = torch.unique(boundary_mask)
|
||||
edge_index_mask = ~torch.isin(edge_index[1], boundary_idx)
|
||||
@@ -119,7 +118,7 @@ class GraphDataModule(LightningDataModule):
|
||||
edge_attr=edge_attr,
|
||||
y=temperature.unsqueeze(-1),
|
||||
boundary_mask=boundary_mask,
|
||||
boundary_values=torch.tensor(0), # Fake value (to fix)
|
||||
boundary_values=boundary_values,
|
||||
)
|
||||
|
||||
return MeshData(
|
||||
@@ -129,7 +128,7 @@ class GraphDataModule(LightningDataModule):
|
||||
pos=pos,
|
||||
edge_attr=edge_attr,
|
||||
boundary_mask=boundary_mask,
|
||||
boundary_values=boundary_values.unsqueeze(-1),
|
||||
boundary_values=boundary_values,
|
||||
y=temperature.unsqueeze(-1),
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user