random changes

This commit is contained in:
Filippo Olivo
2025-10-27 10:23:13 +01:00
parent f49817ca1e
commit 6e90ef5393
7 changed files with 325 additions and 80 deletions

View File

@@ -6,7 +6,6 @@ from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_undirected
from .mesh_data import MeshData
import os
class GraphDataModule(LightningDataModule):
@@ -18,7 +17,7 @@ class GraphDataModule(LightningDataModule):
val_size: float = 0.1,
test_size: float = 0.1,
batch_size: int = 32,
remove_boundary_edges: bool = True,
remove_boundary_edges: bool = False,
):
super().__init__()
self.hf_repo = hf_repo
@@ -82,6 +81,7 @@ class GraphDataModule(LightningDataModule):
temperature = torch.tensor(snapshot["temperature"], dtype=torch.float32)
edge_index = torch.tensor(geometry["edge_index"], dtype=torch.int64).T
pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2]
bottom_ids = torch.tensor(
geometry["bottom_boundary_ids"], dtype=torch.long
@@ -97,7 +97,6 @@ class GraphDataModule(LightningDataModule):
boundary_mask, boundary_values = self._compute_boundary_mask(
bottom_ids, right_ids, top_ids, left_ids, temperature
)
if self.remove_boundary_edges:
boundary_idx = torch.unique(boundary_mask)
edge_index_mask = ~torch.isin(edge_index[1], boundary_idx)
@@ -119,7 +118,7 @@ class GraphDataModule(LightningDataModule):
edge_attr=edge_attr,
y=temperature.unsqueeze(-1),
boundary_mask=boundary_mask,
boundary_values=torch.tensor(0), # Fake value (to fix)
boundary_values=boundary_values,
)
return MeshData(
@@ -129,7 +128,7 @@ class GraphDataModule(LightningDataModule):
pos=pos,
edge_attr=edge_attr,
boundary_mask=boundary_mask,
boundary_values=boundary_values.unsqueeze(-1),
boundary_values=boundary_values,
y=temperature.unsqueeze(-1),
)