From 720931b831aef0001b4840ed38385a399c92ea94 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Tue, 14 Oct 2025 10:02:39 +0200 Subject: [PATCH] update model, model and datamodule --- ThermalSolver/data_module.py | 12 +----------- ThermalSolver/model/local_gno.py | 13 +++---------- ThermalSolver/module.py | 13 +++---------- ThermalSolver/normalizer.py | 26 +++++++++----------------- 4 files changed, 16 insertions(+), 48 deletions(-) diff --git a/ThermalSolver/data_module.py b/ThermalSolver/data_module.py index b3a0a73..a8289af 100644 --- a/ThermalSolver/data_module.py +++ b/ThermalSolver/data_module.py @@ -35,14 +35,6 @@ class GraphDataModule(LightningDataModule): def prepare_data(self): dataset = load_dataset(self.hf_repo, name="snapshots")[self.split_name] geometry = load_dataset(self.hf_repo, name="geometry")[self.split_name] - # data = [ - # self._build_dataset(snapshot, geometry) - # for snapshot, geometry in tqdm( - # zip(hf_dataset, self.geometry), - # desc="Building graphs", - # total=len(hf_dataset), - # ) - # ] total_len = len(dataset) train_len = int(self.train_size * total_len) @@ -127,7 +119,7 @@ class GraphDataModule(LightningDataModule): pos=pos, edge_attr=edge_attr, y=temperature.unsqueeze(-1), - boundary_mask=torch.tensor(0), # Fake value (to fix) + boundary_mask=boundary_mask, boundary_values=torch.tensor(0), # Fake value (to fix) ) @@ -143,8 +135,6 @@ class GraphDataModule(LightningDataModule): ) def setup(self, stage: str = None): - print(type(self.dataset_dict["train"])) - if stage == "fit" or stage is None: self.train_data = [ self._build_dataset(snap, geom) diff --git a/ThermalSolver/model/local_gno.py b/ThermalSolver/model/local_gno.py index d4e1099..7e2e022 100644 --- a/ThermalSolver/model/local_gno.py +++ b/ThermalSolver/model/local_gno.py @@ -77,13 +77,6 @@ class ConditionalGNOBlock(MessagePassing): nn.GELU(), ) - self.x_net = nn.Sequential( - nn.Linear(hidden_ch, hidden_ch * 2), - nn.GELU(), - nn.Linear(hidden_ch * 2, hidden_ch), - nn.GELU(), - ) - self.c_ij_net = nn.Sequential( nn.Linear(hidden_ch, hidden_ch // 2), nn.GELU(), @@ -116,9 +109,6 @@ class ConditionalGNOBlock(MessagePassing): c_ij = 0.5 * (c_i + c_j) gamma = self.gamma_net(torch.cat([x_i, x_j], dim=-1)) gate = self.edge_attr_net(edge_attr) - m = ( - gamma * self.diff_net(x_j - x_i) + (1 - gamma) * self.x_net(x_j) - ) * gate m = self.diff_net(x_j - x_i) * gate m = m * self.c_ij_net(c_ij) return m @@ -158,17 +148,20 @@ class GatingGNO(nn.Module): plot_results=False, batch=None, pos=None, + boundary_mask=None, ): x = self.encoder_x(x) c = self.encoder_c(c) if plot_results: x_ = self.dec(x) plot_results_fn(x_, pos, 0, batch=batch) + bc = x[boundary_mask] for _ in range(1, unrolling_steps + 1): for i, blk in enumerate(self.blocks): x = blk(x, c, edge_index, edge_attr=edge_attr) if plot_results: x_ = self.dec(x) + assert bc == x[boundary_mask] plot_results_fn(x_, pos, i * _, batch=batch) return self.dec(x) diff --git a/ThermalSolver/module.py b/ThermalSolver/module.py index 0329db5..b530310 100644 --- a/ThermalSolver/module.py +++ b/ThermalSolver/module.py @@ -28,20 +28,18 @@ class GraphSolver(LightningModule): self, x: torch.Tensor, c: torch.Tensor, - boundary: torch.Tensor, - boundary_mask: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor, unrolling_steps: int = None, + boundary_mask: torch.Tensor = None, ): return self.model( x, c, - boundary, - boundary_mask, edge_index, edge_attr, unrolling_steps, + boundary_mask=boundary_mask, ) def _compute_loss(self, x, y): @@ -66,11 +64,10 @@ class GraphSolver(LightningModule): y_pred = self( x, c, - batch.boundary_values, - batch.boundary_mask, edge_index=edge_index, edge_attr=edge_attr, unrolling_steps=self.unrolling_steps, + boundary_mask=batch.boundary_mask, ) loss = self.loss(y_pred, y) boundary_loss = self.loss( @@ -85,8 +82,6 @@ class GraphSolver(LightningModule): y_pred = self( x, c, - batch.boundary_values, - batch.boundary_mask, edge_index=edge_index, edge_attr=edge_attr, unrolling_steps=self.unrolling_steps, @@ -104,8 +99,6 @@ class GraphSolver(LightningModule): y_pred = self.model( x, c, - batch.boundary_values, - batch.boundary_mask, edge_index=edge_index, edge_attr=edge_attr, unrolling_steps=self.unrolling_steps, diff --git a/ThermalSolver/normalizer.py b/ThermalSolver/normalizer.py index 4a8b62e..145572f 100644 --- a/ThermalSolver/normalizer.py +++ b/ThermalSolver/normalizer.py @@ -4,8 +4,7 @@ from torch_geometric.data import Data D_IN_KEYS = "x" D_ATTR_KEYS = ["c", "edge_attr"] D_OUT_KEY = "y" -D_KEYS = [D_IN_KEYS] + [D_OUT_KEY] + D_ATTR_KEYS -D_BOUNDS_KEYS = "boundary_temperatures" +D_KEYS = D_ATTR_KEYS + [D_OUT_KEY] class Normalizer: @@ -28,24 +27,17 @@ class Normalizer: std[key] = tmp.std(dim=0, keepdim=True) + 1e-6 return mean, std - def normalize(self, data): + @staticmethod + def _apply_input_boundary(data: Data): + bc = data.y[data.boundary_mask] + data[D_IN_KEYS][data.boundary_mask] = bc + + def normalize(self, data: list[Data]): for d in data: for key in D_KEYS: - if not hasattr(d, key): - raise AttributeError(f"Manca '{key}' in uno dei Data.") d[key] = (d[key] - self.mean[key]) / self.std[key] - self._recompute_boundary_temperatures(data) - - def _recompute_boundary_temperatures(self, data): - for d in data: - bottom_bc = d.y[d.bottom_boundary_ids].median() - top_bc = d.y[d.top_boundary_ids].median() - left_bc = d.y[d.left_boundary_ids].median() - right_bc = d.y[d.right_boundary_ids].median() - boundaries_temperatures = torch.tensor( - [bottom_bc, right_bc, top_bc, left_bc], dtype=torch.float32 - ) - d.boundary_temperatures = boundaries_temperatures.unsqueeze(0) + self._apply_input_boundary(d) + return data def denormalize(self, y: torch.tensor): return y * self.std[D_OUT_KEY] + self.mean[D_OUT_KEY]