From 7a6fbdb89c621a48faeadfabfeb9fb03c5ed225f Mon Sep 17 00:00:00 2001 From: Filippo Olivo Date: Fri, 3 Oct 2025 10:42:05 +0200 Subject: [PATCH] implement new model --- ThermalSolver/data_module.py | 64 +++++++++++++++++++++++--------- ThermalSolver/model/local_gno.py | 31 ++++++++-------- ThermalSolver/module.py | 39 +++++++++---------- 3 files changed, 80 insertions(+), 54 deletions(-) diff --git a/ThermalSolver/data_module.py b/ThermalSolver/data_module.py index a65ed22..f500d0f 100644 --- a/ThermalSolver/data_module.py +++ b/ThermalSolver/data_module.py @@ -17,6 +17,7 @@ class GraphDataModule(LightningDataModule): val_size: float = 0.1, test_size: float = 0.1, batch_size: int = 32, + remove_boundary_edges: bool = True, ): super().__init__() self.hf_repo = hf_repo @@ -27,6 +28,7 @@ class GraphDataModule(LightningDataModule): self.val_size = val_size self.test_size = test_size self.batch_size = batch_size + self.remove_boundary_edges = remove_boundary_edges def prepare_data(self): hf_dataset = load_dataset(self.hf_repo, name="snapshots")[ @@ -44,6 +46,28 @@ class GraphDataModule(LightningDataModule): ) ] + def _compute_boundary_mask( + self, bottom_ids, right_ids, top_ids, left_ids, temperature + ): + left_ids = left_ids[~torch.isin(left_ids, bottom_ids)] + right_ids = right_ids[~torch.isin(right_ids, bottom_ids)] + left_ids = left_ids[~torch.isin(left_ids, top_ids)] + right_ids = right_ids[~torch.isin(right_ids, top_ids)] + + bottom_bc = temperature[bottom_ids].median() + bottom_bc_mask = torch.ones(len(bottom_ids)) * bottom_bc + left_bc = temperature[left_ids].median() + left_bc_mask = torch.ones(len(left_ids)) * left_bc + right_bc = temperature[right_ids].median() + right_bc_mask = torch.ones(len(right_ids)) * right_bc + + boundary_values = torch.cat( + [bottom_bc_mask, right_bc_mask, left_bc_mask], dim=0 + ) + boundary_mask = torch.cat([bottom_ids, right_ids, left_ids], dim=0) + + return boundary_mask, boundary_values + def _build_dataset( self, snapshot: dict, @@ -66,27 +90,34 @@ class GraphDataModule(LightningDataModule): ) edge_index = to_undirected(edge_index, num_nodes=pos.size(0)) + + boundary_mask, boundary_values = self._compute_boundary_mask( + bottom_ids, right_ids, top_ids, left_ids, temperature + ) + + if self.remove_boundary_edges: + boundary_idx = torch.unique(boundary_mask) + edge_index_mask = ~torch.isin(edge_index[1], boundary_idx) + edge_index = edge_index[:, edge_index_mask] + edge_attr = pos[edge_index[0]] - pos[edge_index[1]] edge_attr = torch.cat( [edge_attr, torch.norm(edge_attr, dim=1).unsqueeze(-1)], dim=1 ) - left_ids = left_ids[~torch.isin(left_ids, bottom_ids)] - right_ids = right_ids[~torch.isin(right_ids, bottom_ids)] - left_ids = left_ids[~torch.isin(left_ids, top_ids)] - right_ids = right_ids[~torch.isin(right_ids, top_ids)] - - bottom_bc = temperature[bottom_ids].median() - bottom_bc_mask = torch.ones(len(bottom_ids)) * bottom_bc - left_bc = temperature[left_ids].median() - left_bc_mask = torch.ones(len(left_ids)) * left_bc - right_bc = temperature[right_ids].median() - right_bc_mask = torch.ones(len(right_ids)) * right_bc - - boundary_values = torch.cat( - [bottom_bc_mask, right_bc_mask, left_bc_mask], dim=0 - ) - boundary_mask = torch.cat([bottom_ids, right_ids, left_ids], dim=0) + x = torch.zeros_like(temperature, dtype=torch.float32).unsqueeze(-1) + if self.remove_boundary_edges: + x[boundary_mask] = boundary_values.unsqueeze(-1) + return MeshData( + x=x, + c=conductivity.unsqueeze(-1), + edge_index=edge_index, + pos=pos, + edge_attr=edge_attr, + y=temperature.unsqueeze(-1), + boundary_mask=boundary_mask, + boundary_values=torch.tensor(0), + ) return MeshData( x=torch.rand_like(temperature).unsqueeze(-1), @@ -110,7 +141,6 @@ class GraphDataModule(LightningDataModule): if stage == "test" or stage is None: self.test_data = self.data[val_end:] - # nel tuo LightningDataModule def train_dataloader(self): return DataLoader( self.train_data, batch_size=self.batch_size, shuffle=True diff --git a/ThermalSolver/model/local_gno.py b/ThermalSolver/model/local_gno.py index 7177038..e6bf568 100644 --- a/ThermalSolver/model/local_gno.py +++ b/ThermalSolver/model/local_gno.py @@ -4,10 +4,6 @@ from torch_geometric.nn import MessagePassing from matplotlib.tri import Triangulation -def _import_boundary_conditions(x, boundary, boundary_mask): - x[boundary_mask] = boundary - - def plot_results_fn(x, pos, i, batch): x = x[batch == 0] pos = pos[batch == 0] @@ -63,7 +59,6 @@ class DecX(nn.Module): class ConditionalGNOBlock(MessagePassing): def __init__(self, hidden_ch, edge_ch=0, aggr="mean"): super().__init__(aggr=aggr, node_dim=0) - # self.film_msg = FiLM(c_ch=hidden_ch, h_ch=hidden_ch) self.edge_attr_net = nn.Sequential( nn.Linear(edge_ch, hidden_ch // 2), @@ -73,9 +68,9 @@ class ConditionalGNOBlock(MessagePassing): ) self.msg_proj = nn.Sequential( - nn.Linear(hidden_ch, hidden_ch), + nn.Linear(hidden_ch, hidden_ch, bias=False), nn.SiLU(), - nn.Linear(hidden_ch, hidden_ch), + nn.Linear(hidden_ch, hidden_ch, bias=False), ) self.diff_net = nn.Sequential( @@ -98,7 +93,15 @@ class ConditionalGNOBlock(MessagePassing): ) self.balancing = nn.Parameter(torch.tensor(0.0)) - self.alpha = nn.Parameter(torch.tensor(1.0)) + + self.alpha_net = nn.Sequential( + nn.Linear(2 * hidden_ch, hidden_ch), + nn.SiLU(), + nn.Linear(hidden_ch, hidden_ch // 2), + nn.SiLU(), + nn.Linear(hidden_ch // 2, 1), + nn.Sigmoid(), + ) def forward(self, x, c, edge_index, edge_attr=None): return self.propagate(edge_index, x=x, c=c, edge_attr=edge_attr) @@ -108,14 +111,14 @@ class ConditionalGNOBlock(MessagePassing): alpha = torch.sigmoid(self.balancing) gate = torch.sigmoid(self.edge_attr_net(edge_attr)) m = ( - alpha * self.diff_net(x_j - x_i) - + (1 - alpha) * self.x_net(x_j) * gate - ) + alpha * self.diff_net(x_j - x_i) + (1 - alpha) * self.x_net(x_j) + ) * gate m = m * self.c_ij_net(c_ij) return m def update(self, aggr_out, x): - return x + self.alpha * self.msg_proj(aggr_out) + alpha = self.alpha_net(torch.cat([x, aggr_out], dim=-1)) + return x + alpha * self.msg_proj(aggr_out) class GatingGNO(nn.Module): @@ -153,14 +156,10 @@ class GatingGNO(nn.Module): ): x = self.encoder_x(x) c = self.encoder_c(c) - boundary = self.encoder_x(boundary) if plot_results: - _import_boundary_conditions(x, boundary, boundary_mask) x_ = self.dec(x) plot_results_fn(x_, pos, 0, batch=batch) - for _ in range(1, unrolling_steps + 1): - _import_boundary_conditions(x, boundary, boundary_mask) for blk in self.blocks: x = blk(x, c, edge_index, edge_attr=edge_attr) if plot_results: diff --git a/ThermalSolver/module.py b/ThermalSolver/module.py index f9d9e2e..842cb69 100644 --- a/ThermalSolver/module.py +++ b/ThermalSolver/module.py @@ -1,31 +1,26 @@ import torch from lightning import LightningModule from torch_geometric.data import Batch -from matplotlib.tri import Triangulation +import importlib -# def plot_results(x, pos, step, i, batch): -# x = x[batch == 0] -# pos = pos[batch == 0] -# tria = Triangulation(pos[:, 0].cpu(), pos[:, 1].cpu()) -# import matplotlib.pyplot as plt - -# plt.tricontourf(tria, x[:, 0].cpu(), levels=14) -# plt.colorbar() -# plt.savefig(f"{step:03d}_out_{i:03d}.png") -# plt.axis("equal") -# plt.close() +def import_class(class_path: str): + module_path, class_name = class_path.rsplit(".", 1) # split last dot + module = importlib.import_module(module_path) # import the module + cls = getattr(module, class_name) # get the class + return cls class GraphSolver(LightningModule): def __init__( self, - model: torch.nn.Module, + model_class_path: str, + model_init_args: dict, loss: torch.nn.Module = None, unrolling_steps: int = 48, ): super().__init__() - self.model = model + self.model = import_class(model_class_path)(**model_init_args) self.loss = loss if loss is not None else torch.nn.MSELoss() self.unrolling_steps = unrolling_steps @@ -57,7 +52,7 @@ class GraphSolver(LightningModule): def _log_loss(self, loss, batch, stage: str): self.log( - f"{stage}_loss", + f"{stage}/loss", loss, on_step=False, on_epoch=True, @@ -68,8 +63,6 @@ class GraphSolver(LightningModule): def training_step(self, batch: Batch, _): x, y, c, edge_index, edge_attr = self._preprocess_batch(batch) - # x = self._impose_bc(x, batch) - # for _ in range(self.unrolling_steps): y_pred = self( x, c, @@ -79,9 +72,12 @@ class GraphSolver(LightningModule): edge_attr=edge_attr, unrolling_steps=self.unrolling_steps, ) - # x = self._impose_bc(x, batch) loss = self.loss(y_pred, y) + boundary_loss = self.loss( + y_pred[batch.boundary_mask], y[batch.boundary_mask] + ) self._log_loss(loss, batch, "train") + self._log_loss(boundary_loss, batch, "train_boundary") return loss def validation_step(self, batch: Batch, _): @@ -96,12 +92,15 @@ class GraphSolver(LightningModule): unrolling_steps=self.unrolling_steps, ) loss = self.loss(y_pred, y) + boundary_loss = self.loss( + y_pred[batch.boundary_mask], y[batch.boundary_mask] + ) self._log_loss(loss, batch, "val") + self._log_loss(boundary_loss, batch, "val_boundary") return loss def test_step(self, batch: Batch, _): x, y, c, edge_index, edge_attr = self._preprocess_batch(batch) - # for _ in range(self.unrolling_steps): y_pred = self.model( x, c, @@ -114,8 +113,6 @@ class GraphSolver(LightningModule): batch=batch.batch, pos=batch.pos, ) - # x = self._impose_bc(x, batch) - # plot_results(x, batch.pos, self.global_step, _, batch.batch) loss = self._compute_loss(y_pred, y) self._log_loss(loss, batch, "test") return loss