fix module and model + add curriculum callback
@@ -7,44 +7,13 @@ from torch_geometric.loader import DataLoader
 from torch_geometric.utils import to_undirected
 from .mesh_data import MeshData

 # from torch.utils.data import Dataset
-from torch_geometric.utils import scatter
-
-
-def compute_nodal_area(edge_index, edge_attr, num_nodes):
-    """
-    1. Estimates area ~ (min edge length)^2.
-    2. Scales by the mean so the average cell has size 1.0.
-    """
-    row, col = edge_index
-    dist = edge_attr.squeeze()
-
-    # 1. Get h, the closest-neighbor distance. Reducing with 'min'
-    # filters out diagonal connections in the quad mesh.
-    h = scatter(dist, col, dim=0, dim_size=num_nodes, reduce="min")
-
-    # 2. Estimate the raw area.
-    raw_area = h.pow(2)
-
-    # 3. Mean scaling keeps values near 1.0, preserving stability and
-    # physical ratios. Detached so no gradients flow here (static data).
-    mean_val = raw_area.mean().detach()
-
-    # Result: small cells ~0.1, large cells ~5.0, average 1.0.
-    # nodal_area = (raw_area / mean_val).unsqueeze(-1) + 1e-6
-    nodal_area = raw_area
-    return nodal_area.unsqueeze(-1)
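For reference, a quick sanity check of the scatter-min step in the (now removed) compute_nodal_area, on a hypothetical toy mesh not taken from this repo: three collinear nodes at x = 0, 1, 3 give closest-neighbor distances h = [1, 1, 2] and area proxies h^2 = [1, 1, 4].

    import torch
    from torch_geometric.utils import scatter

    # Hypothetical 1-D mesh: nodes at x = 0, 1, 3; both directions of each edge.
    edge_index = torch.tensor([[0, 1, 1, 2],
                               [1, 0, 2, 1]])
    dist = torch.tensor([1.0, 1.0, 2.0, 2.0])  # |x_src - x_dst| per edge

    h = scatter(dist, edge_index[1], dim=0, dim_size=3, reduce="min")
    print(h.pow(2))  # tensor([1., 1., 4.])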

 class GraphDataModule(LightningDataModule):
     def __init__(
         self,
         hf_repo: str,
         split_name: str,
         n_elements: int = None,
         train_size: float = 0.2,
         val_size: float = 0.1,
         test_size: float = 0.1,
@@ -52,18 +21,19 @@ class GraphDataModule(LightningDataModule):
         remove_boundary_edges: bool = False,
         build_radial_graph: bool = False,
         radius: float = None,
+        start_unrolling_steps: int = 1,
         unrolling_steps: int = 1,
     ):
         super().__init__()
         self.hf_repo = hf_repo
         self.split_name = split_name
         self.n_elements = n_elements
         self.dataset_dict = {}
         self.train_dataset, self.val_dataset, self.test_dataset = (
             None,
             None,
             None,
         )
-        self.unrolling_steps = unrolling_steps
+        self.unrolling_steps = start_unrolling_steps
         self.geometry_dict = {}
         self.train_size = train_size
         self.val_size = val_size
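The curriculum callback named in the commit title is not part of this hunk. A minimal sketch of what it could look like, assuming it grows the data module's unrolling_steps by one per epoch up to a cap; the class name, cap parameter, and import path are assumptions, not code from this repo:

    from lightning.pytorch.callbacks import Callback

    class UnrollingCurriculum(Callback):
        """Hypothetical curriculum: lengthen the unrolling horizon over training."""

        def __init__(self, max_unrolling_steps: int):
            self.max_unrolling_steps = max_unrolling_steps

        def on_train_epoch_end(self, trainer, pl_module):
            dm = trainer.datamodule  # assumes the module is passed via datamodule=
            dm.unrolling_steps = min(dm.unrolling_steps + 1, self.max_unrolling_steps)

For the longer horizon to reach the training batches, the Trainer would also need reload_dataloaders_every_n_epochs=1 so the train dataloader is rebuilt each epoch.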
@@ -76,6 +46,9 @@ class GraphDataModule(LightningDataModule):
     def prepare_data(self):
         dataset = load_dataset(self.hf_repo, name="snapshots")[self.split_name]
+        geometry = load_dataset(self.hf_repo, name="geometry")[self.split_name]
         if self.n_elements is not None:
             dataset = dataset.select(range(self.n_elements))
+            geometry = geometry.select(range(self.n_elements))

         total_len = len(dataset)
         train_len = int(self.train_size * total_len)
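Note that the default fractions (train 0.2, val 0.1, test 0.1) cover only 40% of the snapshots. Assuming val_len and test_len are computed the same way further down (the rest of prepare_data is outside this hunk) and the remainder is not selected elsewhere:

    total_len = 100
    train_len = int(0.2 * total_len)  # 20 snapshots for training
    val_len = int(0.1 * total_len)    # 10 for validation
    test_len = int(0.1 * total_len)   # 10 for test; the other 60 would go unused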
@@ -117,13 +90,18 @@ class GraphDataModule(LightningDataModule):
         self,
         snapshot: dict,
         geometry: dict,
+        test: bool = False,
     ) -> Data:
         conductivity = torch.tensor(
             geometry["conductivity"], dtype=torch.float32
         )
-        temperatures = torch.tensor(
-            snapshot["temperatures"], dtype=torch.float32
-        )[:40]
+        temperatures = (
+            torch.tensor(snapshot["temperatures"], dtype=torch.float32)[:40]
+            if not test
+            else torch.tensor(snapshot["temperatures"], dtype=torch.float32)[
+                : self.unrolling_steps + 1
+            ]
+        )
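The effect of the new test flag, sketched on a hypothetical snapshot with 100 stored frames and self.unrolling_steps = 5: training keeps a fixed 40-frame window, while testing keeps only the initial state plus 5 rollout targets.

    import torch

    frames = torch.randn(100, 64)   # hypothetical (time steps, nodes) field
    train_window = frames[:40]      # fixed window used for train/val graphs
    test_window = frames[:5 + 1]    # unrolling_steps + 1 frames
    print(train_window.shape, test_window.shape)
    # torch.Size([40, 64]) torch.Size([6, 64])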
         times = torch.tensor(snapshot["times"], dtype=torch.float32)

         pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2]
@@ -138,16 +116,6 @@ class GraphDataModule(LightningDataModule):
         )

         if self.build_radial_graph:
-            # from pina.graph import RadiusGraph
-
-            # if self.radius is None:
-            #     raise ValueError("Radius must be specified for radial graph.")
-            # edge_index = RadiusGraph.compute_radius_graph(
-            #     pos, radius=self.radius
-            # )
-            # from torch_geometric.utils import remove_self_loops
-
-            # edge_index, _ = remove_self_loops(edge_index)
             raise NotImplementedError(
                 "Radial graph building not implemented yet."
             )
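The deleted comments point at one possible implementation of this branch. A sketch using torch_geometric's radius_graph (requires the optional torch-cluster dependency); this is an assumption about how the branch could be filled in, not the repo's plan:

    import torch
    from torch_geometric.nn import radius_graph

    pos = torch.rand(50, 2)                # hypothetical 2-D node positions
    edge_index = radius_graph(pos, r=0.2)  # loop=False by default: no self-loops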
@@ -161,7 +129,6 @@ class GraphDataModule(LightningDataModule):
             bottom_ids, right_ids, top_ids, left_ids, temperatures[0, :]
         )
         edge_attr = torch.norm(pos[edge_index[0]] - pos[edge_index[1]], dim=1)
-        nodal_area = compute_nodal_area(edge_index, edge_attr, pos.size(0))
         if self.remove_boundary_edges:
             boundary_idx = torch.unique(boundary_mask)
             edge_index_mask = ~torch.isin(edge_index[1], boundary_idx)
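How the mask behaves, on a hypothetical 4-edge graph where node 0 is the only boundary node: edges pointing *into* node 0 are dropped. (The hunk ends before the mask is applied, presumably to edge_index and edge_attr.)

    import torch

    edge_index = torch.tensor([[1, 2, 3, 0],
                               [0, 1, 2, 3]])
    boundary_idx = torch.tensor([0])
    keep = ~torch.isin(edge_index[1], boundary_idx)
    print(edge_index[:, keep])  # tensor([[2, 3, 0],
                                #         [1, 2, 3]])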
@@ -186,7 +153,6 @@ class GraphDataModule(LightningDataModule):
                 edge_attr=edge_attr,
                 boundary_mask=boundary_mask,
                 boundary_values=boundary_values,
-                nodal_area=nodal_area,
             )
         )
         return data
@@ -213,7 +179,7 @@ class GraphDataModule(LightningDataModule):
         ]
         if stage == "test" or stage is None:
             self.test_data = [
-                self._build_dataset(snap, geom)
+                self._build_dataset(snap, geom, test=True)
                 for snap, geom in tqdm(
                     zip(self.dataset_dict["test"], self.geometry_dict["test"]),
                     desc="Building test graphs",
@@ -234,7 +200,9 @@ class GraphDataModule(LightningDataModule):
         # self.train_dataset = ds
         # print(type(self.train_data[0]))
         ds = [i for data in self.train_data for i in data]
         # print(type(ds[0]))
+        print(
+            f"\nLoading training data, using {self.unrolling_steps} unrolling steps..."
+        )
         return DataLoader(
             ds,
             batch_size=self.batch_size,
@@ -244,6 +212,9 @@ class GraphDataModule(LightningDataModule):
         )

     def val_dataloader(self):
+        print(
+            f"\nLoading validation data, using {self.unrolling_steps} unrolling steps..."
+        )
         ds = [i for data in self.val_data for i in data]
         return DataLoader(
             ds,
@@ -254,12 +225,10 @@ class GraphDataModule(LightningDataModule):
         )

     def test_dataloader(self):
-        ds = self.create_autoregressive_datasets(
-            dataset="test", no_unrolling=True
-        )
+        ds = [i for data in self.test_data for i in data]
         return DataLoader(
             ds,
-            batch_size=self.batch_size,
+            batch_size=1,
             shuffle=False,
             num_workers=8,
             pin_memory=True,
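An end-to-end usage sketch tying the pieces together. The model, the UnrollingCurriculum callback sketched earlier, and the repo id are assumptions; reload_dataloaders_every_n_epochs is the Lightning Trainer option that makes the per-epoch horizon change take effect:

    import lightning.pytorch as pl

    dm = GraphDataModule(
        hf_repo="user/heat-snapshots",  # hypothetical dataset id
        split_name="train",
        start_unrolling_steps=1,
        unrolling_steps=8,
    )
    trainer = pl.Trainer(
        max_epochs=50,
        callbacks=[UnrollingCurriculum(max_unrolling_steps=8)],
        reload_dataloaders_every_n_epochs=1,
    )
    # trainer.fit(model, datamodule=dm)  # the model definition is outside this diff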