fix module and model + add curriculum callback

This commit is contained in:
Filippo Olivo
2025-12-09 09:18:36 +01:00
parent 2935785b31
commit f2ce282a68
4 changed files with 243 additions and 111 deletions

View File

@@ -7,44 +7,13 @@ from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_undirected
from .mesh_data import MeshData
# from torch.utils.data import Dataset
from torch_geometric.utils import scatter
def compute_nodal_area(
    edge_index: torch.Tensor, edge_attr: torch.Tensor, num_nodes: int
) -> torch.Tensor:
    """Estimate a per-node area as the squared shortest incoming edge length.

    For every node, the smallest length among the edges that point to it
    (i.e. edges whose target index ``edge_index[1]`` equals the node) is
    taken as the local spacing ``h``; the returned area estimate is ``h**2``.

    No normalization is applied: an earlier variant divided by the mean
    area so that the average cell had size 1.0, but that scaling is
    disabled and the raw value is returned.

    Args:
        edge_index: Edge list whose second row holds the target node index
            of each edge.
        edge_attr: One scalar length per edge; a trailing singleton
            dimension is accepted.
        num_nodes: Number of nodes, used as the size of the reduced output.

    Returns:
        Tensor of shape ``(num_nodes, 1)`` with the squared minimum incoming
        edge length of each node.
    """
    # Flatten instead of ``squeeze()``: a bare squeeze would collapse a
    # single-edge input to a 0-dim tensor and lose the edge axis that the
    # scatter reduction below indexes along.
    dist = edge_attr.reshape(-1)
    target = edge_index[1]

    # Closest-neighbour distance per node. Per the original author's note,
    # reducing with "min" is meant to ignore the longer diagonal
    # connections of the quad mesh.
    h = scatter(dist, target, dim=0, dim_size=num_nodes, reduce="min")

    return h.pow(2).unsqueeze(-1)
class GraphDataModule(LightningDataModule):
def __init__(
self,
hf_repo: str,
split_name: str,
n_elements: int = None,
train_size: float = 0.2,
val_size: float = 0.1,
test_size: float = 0.1,
@@ -52,18 +21,19 @@ class GraphDataModule(LightningDataModule):
remove_boundary_edges: bool = False,
build_radial_graph: bool = False,
radius: float = None,
start_unrolling_steps: int = 1,
unrolling_steps: int = 1,
):
super().__init__()
self.hf_repo = hf_repo
self.split_name = split_name
self.n_elements = n_elements
self.dataset_dict = {}
self.train_dataset, self.val_dataset, self.test_dataset = (
None,
None,
None,
)
self.unrolling_steps = start_unrolling_steps
self.unrolling_steps = unrolling_steps
self.geometry_dict = {}
self.train_size = train_size
self.val_size = val_size
@@ -76,6 +46,9 @@ class GraphDataModule(LightningDataModule):
def prepare_data(self):
dataset = load_dataset(self.hf_repo, name="snapshots")[self.split_name]
geometry = load_dataset(self.hf_repo, name="geometry")[self.split_name]
if self.n_elements is not None:
dataset = dataset.select(range(self.n_elements))
geometry = geometry.select(range(self.n_elements))
total_len = len(dataset)
train_len = int(self.train_size * total_len)
@@ -117,13 +90,18 @@ class GraphDataModule(LightningDataModule):
self,
snapshot: dict,
geometry: dict,
test: bool = False,
) -> Data:
conductivity = torch.tensor(
geometry["conductivity"], dtype=torch.float32
)
temperatures = torch.tensor(
snapshot["temperatures"], dtype=torch.float32
)[:40]
temperatures = (
torch.tensor(snapshot["temperatures"], dtype=torch.float32)[:40]
if not test
else torch.tensor(snapshot["temperatures"], dtype=torch.float32)[
: self.unrolling_steps + 1
]
)
times = torch.tensor(snapshot["times"], dtype=torch.float32)
pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2]
@@ -138,16 +116,6 @@ class GraphDataModule(LightningDataModule):
)
if self.build_radial_graph:
# from pina.graph import RadiusGraph
# if self.radius is None:
# raise ValueError("Radius must be specified for radial graph.")
# edge_index = RadiusGraph.compute_radius_graph(
# pos, radius=self.radius
# )
# from torch_geometric.utils import remove_self_loops
# edge_index, _ = remove_self_loops(edge_index)
raise NotImplementedError(
"Radial graph building not implemented yet."
)
@@ -161,7 +129,6 @@ class GraphDataModule(LightningDataModule):
bottom_ids, right_ids, top_ids, left_ids, temperatures[0, :]
)
edge_attr = torch.norm(pos[edge_index[0]] - pos[edge_index[1]], dim=1)
nodal_area = compute_nodal_area(edge_index, edge_attr, pos.size(0))
if self.remove_boundary_edges:
boundary_idx = torch.unique(boundary_mask)
edge_index_mask = ~torch.isin(edge_index[1], boundary_idx)
@@ -186,7 +153,6 @@ class GraphDataModule(LightningDataModule):
edge_attr=edge_attr,
boundary_mask=boundary_mask,
boundary_values=boundary_values,
nodal_area=nodal_area,
)
)
return data
@@ -213,7 +179,7 @@ class GraphDataModule(LightningDataModule):
]
if stage == "test" or stage is None:
self.test_data = [
self._build_dataset(snap, geom)
self._build_dataset(snap, geom, test=True)
for snap, geom in tqdm(
zip(self.dataset_dict["test"], self.geometry_dict["test"]),
desc="Building test graphs",
@@ -234,7 +200,9 @@ class GraphDataModule(LightningDataModule):
# self.train_dataset = ds
# print(type(self.train_data[0]))
ds = [i for data in self.train_data for i in data]
# print(type(ds[0]))
print(
f"\nLoading training data, using {self.unrolling_steps} unrolling steps..."
)
return DataLoader(
ds,
batch_size=self.batch_size,
@@ -244,6 +212,9 @@ class GraphDataModule(LightningDataModule):
)
def val_dataloader(self):
print(
f"\nLoading validation data, using {self.unrolling_steps} unrolling steps..."
)
ds = [i for data in self.val_data for i in data]
return DataLoader(
ds,
@@ -254,12 +225,10 @@ class GraphDataModule(LightningDataModule):
)
def test_dataloader(self):
ds = self.create_autoregressive_datasets(
dataset="test", no_unrolling=True
)
ds = [i for data in self.test_data for i in data]
return DataLoader(
ds,
batch_size=self.batch_size,
batch_size=1,
shuffle=False,
num_workers=8,
pin_memory=True,