From c6c416e682ab06774159bfb09d702f95b191a699 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Fri, 26 Sep 2025 09:10:02 +0200 Subject: [PATCH] minor --- ThermalSolver/data_module.py | 9 ++++++++- ThermalSolver/module.py | 9 ++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/ThermalSolver/data_module.py b/ThermalSolver/data_module.py index 905575a..f57590f 100644 --- a/ThermalSolver/data_module.py +++ b/ThermalSolver/data_module.py @@ -40,6 +40,9 @@ class GraphDataModule(LightningDataModule): pos = torch.tensor(self.geometry["points"][0], dtype=torch.float32)[ :, :2 ] + bottom_boundary_ids = torch.tensor( + self.geometry["bottom_boundary_ids"][0], dtype=torch.int64 + ) self.data = [ self._build_dataset( torch.tensor(snapshot["conductivity"], dtype=torch.float32), @@ -47,6 +50,7 @@ class GraphDataModule(LightningDataModule): torch.tensor(snapshot["temperature"], dtype=torch.float32), edge_index.T, pos, + bottom_boundary_ids, ) for snapshot in tqdm(hf_dataset, desc="Building graphs") ] @@ -58,13 +62,15 @@ class GraphDataModule(LightningDataModule): temperature: torch.Tensor, edge_index: torch.Tensor, pos: torch.Tensor, + bottom_boundary_ids: torch.Tensor, ) -> Data: edge_index = to_undirected(edge_index, num_nodes=pos.size(0)) edge_attr = pos[edge_index[0]] - pos[edge_index[1]] edge_attr = torch.cat( [edge_attr, torch.norm(edge_attr, dim=1).unsqueeze(-1)], dim=1 ) - + boundary_temperature = boundary_vales[bottom_boundary_ids].max() + boundary_vales[bottom_boundary_ids] = 1.0 return Data( x=boundary_vales.unsqueeze(-1), c=conductivity.unsqueeze(-1), @@ -72,6 +78,7 @@ class GraphDataModule(LightningDataModule): pos=pos, edge_attr=edge_attr, y=temperature.unsqueeze(-1), + boundary_temperature=boundary_vales[bottom_boundary_ids].max(), ) def setup(self, stage: str = None): diff --git a/ThermalSolver/module.py b/ThermalSolver/module.py index 83de199..2acd94f 100644 --- a/ThermalSolver/module.py +++ b/ThermalSolver/module.py @@ -50,7 +50,10 @@ class GraphSolver(LightningModule): for _ in range(self.unrolling_steps): x_prev = x.detach() x = self(x_prev, c, edge_index=edge_index, edge_attr=edge_attr) - loss += self.loss(x, y) + actual_loss = self.loss(x, y) + loss += actual_loss + print(f"Train step loss: {actual_loss.item()}") + self._log_loss(loss, batch, "train") return loss @@ -78,3 +81,7 @@ class GraphSolver(LightningModule): def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) return optimizer + + def scale_bc(self, data: Batch, y: torch.Tensor): + t = data.boundary_temperature[data.batch] + return y * t