diff --git a/ThermalSolver/data_module.py b/ThermalSolver/data_module.py
index 3b65f3a..1b74f16 100644
--- a/ThermalSolver/data_module.py
+++ b/ThermalSolver/data_module.py
@@ -4,6 +4,7 @@ from lightning import LightningDataModule
 from datasets import load_dataset
 from torch_geometric.data import Data
 from torch_geometric.loader import DataLoader
+from torch_geometric.utils import to_undirected
 
 
 class GraphDataModule(LightningDataModule):
@@ -34,7 +35,7 @@ class GraphDataModule(LightningDataModule):
             self.split_name
         ]
         edge_index = torch.tensor(
-            self.geometry["edge_index"][0], dtype=torch.int32
+            self.geometry["edge_index"][0], dtype=torch.int64
         )
         pos = torch.tensor(self.geometry["points"][0], dtype=torch.float32)[
             :, :2
@@ -51,23 +52,29 @@ class GraphDataModule(LightningDataModule):
         ]
 
     def _build_dataset(
-        self, conductivity, boundary_vales, temperature, edge_index, pos
-    ):
-        input_ = torch.stack([conductivity, boundary_vales], dim=1)
+        self,
+        conductivity: torch.Tensor,
+        boundary_values: torch.Tensor,
+        temperature: torch.Tensor,
+        edge_index: torch.Tensor,
+        pos: torch.Tensor,
+    ) -> Data:
+        edge_index = to_undirected(edge_index, num_nodes=pos.size(0))
         edge_attr = pos[edge_index[0]] - pos[edge_index[1]]
         edge_attr = torch.cat(
             [edge_attr, torch.norm(edge_attr, dim=1).unsqueeze(-1)], dim=1
         )
-
+
         return Data(
-            x=input_,
+            x=boundary_values.unsqueeze(-1),
+            c=conductivity.unsqueeze(-1),
             edge_index=edge_index,
             pos=pos,
             edge_attr=edge_attr,
-            y=temperature,
+            y=temperature.unsqueeze(-1),
         )
 
-    def setup(self, stage=None):
+    def setup(self, stage: str | None = None):
         n = len(self.data)
         train_end = int(n * self.train_size)
         val_end = train_end + int(n * self.val_size)
@@ -78,13 +85,13 @@ class GraphDataModule(LightningDataModule):
         if stage == "test" or stage is None:
             self.test_data = self.data[val_end:]
 
-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader:
         return DataLoader(
             self.train_data, batch_size=self.batch_size, shuffle=True
         )
 
-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader:
         return DataLoader(self.val_data, batch_size=self.batch_size)
 
-    def test_dataloader(self):
+    def test_dataloader(self) -> DataLoader:
         return DataLoader(self.test_data, batch_size=self.batch_size)
diff --git a/ThermalSolver/model/local_gno.py b/ThermalSolver/model/local_gno.py
new file mode 100644
index 0000000..e9c575d
--- /dev/null
+++ b/ThermalSolver/model/local_gno.py
@@ -0,0 +1,113 @@
+import torch
+from torch import nn
+from torch_geometric.nn import MessagePassing
+
+
+# ---- FiLM that starts as identity and normalizes the target ----
+class FiLM(nn.Module):
+    def __init__(self, c_ch, h_ch):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(c_ch, 2*h_ch),
+            nn.SiLU(),
+            nn.Linear(2*h_ch, 2*h_ch)
+        )
+        # init to identity: gamma≈0 (so 1+gamma=1), beta=0
+        nn.init.zeros_(self.net[-1].weight)
+        nn.init.zeros_(self.net[-1].bias)
+        self.norm = nn.LayerNorm(h_ch)
+
+    def forward(self, h, c):
+        gb = self.net(c)
+        gamma, beta = gb.chunk(2, dim=-1)
+        return (1 + gamma) * self.norm(h) + beta
+
+
+class ConditionalGNOBlock(MessagePassing):
+    """
+    Message passing with FiLM applied to the MESSAGE m_ij,
+    using edge context c_ij = (c_i + c_j)/2.
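+
+    Schematically, message() below computes (with e_ij = edge_attr):
+        m_ij = (1 + gamma(c_ij)) * LayerNorm(MLP([x_i, x_j, e_ij])) + beta(c_ij)
+    where gamma and beta come from the zero-initialized FiLM net, so the
+    block starts out as plain (normalized) message passing.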
+    """
+    def __init__(self, hidden_ch, edge_ch=0, aggr="mean"):
+        super().__init__(aggr=aggr, node_dim=0)
+        self.pre_norm = nn.LayerNorm(hidden_ch)
+
+        # raw message builder
+        self.msg = nn.Sequential(
+            nn.Linear(2*hidden_ch + edge_ch, 2*hidden_ch),
+            nn.SiLU(),
+            nn.Linear(2*hidden_ch, hidden_ch)
+        )
+
+        # FiLM over the message (per-edge)
+        self.film_msg = FiLM(c_ch=hidden_ch, h_ch=hidden_ch)
+
+        # node update with residual
+        self.update_mlp = nn.Sequential(
+            nn.Linear(2*hidden_ch, hidden_ch),
+            nn.SiLU(),
+            nn.Linear(hidden_ch, hidden_ch)
+        )
+
+    def forward(self, x, c, edge_index, edge_attr=None):
+        # pre-norm helps stability
+        x_in = x
+        x = self.pre_norm(x)
+        m = self.propagate(edge_index, x=x, c=c, edge_attr=edge_attr)
+        out = self.update_mlp(torch.cat([x_in, m], dim=-1))
+        return x_in + out  # residual
+
+    def message(self, x_i, x_j, c_i, c_j, edge_attr):
+        if edge_attr is not None:
+            m_in = torch.cat([x_i, x_j, edge_attr], dim=-1)
+        else:
+            m_in = torch.cat([x_i, x_j], dim=-1)
+
+        m_raw = self.msg(m_in)
+
+        # edge conditioning: simple mean
+        c_ctx = 0.5 * (c_i + c_j)
+        m = self.film_msg(m_raw, c_ctx)
+        return m
+
+
+class GatingGNO(nn.Module):
+    """
+    In:
+        x : [N, Cx]  (e.g., u or features to predict from)
+        c : [N, Cf]  (conditioning field, e.g., conductivity)
+    Out:
+        y : [N, out_ch]
+    """
+    def __init__(self, x_ch_node, f_ch_node, hidden, layers, edge_ch=0, out_ch=1):
+        super().__init__()
+        self.encoder_x = nn.Sequential(
+            nn.Linear(x_ch_node, hidden // 2),
+            nn.SiLU(),
+            nn.Linear(hidden // 2, hidden),
+        )
+        self.encoder_c = nn.Sequential(
+            nn.Linear(f_ch_node, hidden // 2),
+            nn.SiLU(),
+            nn.Linear(hidden // 2, hidden),
+        )
+        self.blocks = nn.ModuleList(
+            [ConditionalGNOBlock(hidden_ch=hidden, edge_ch=edge_ch) for _ in range(layers)]
+        )
+        self.dec = nn.Sequential(
+            nn.LayerNorm(hidden),
+            nn.SiLU(),
+            nn.Linear(hidden, out_ch)
+        )
+
+    def forward(self, x, c, edge_index, edge_attr=None):
+        x = self.encoder_x(x)  # [N,H]
+        c = self.encoder_c(c)  # [N,H]
+        for blk in self.blocks:
+            x = blk(x, c, edge_index, edge_attr=edge_attr)
+        return self.dec(x)
\ No newline at end of file
diff --git a/ThermalSolver/module.py b/ThermalSolver/module.py
new file mode 100644
index 0000000..922d810
--- /dev/null
+++ b/ThermalSolver/module.py
@@ -0,0 +1,76 @@
+import torch
+from lightning import LightningModule
+from torch_geometric.data import Batch
+
+
+class GraphSolver(LightningModule):
+    def __init__(self, model: torch.nn.Module, loss: torch.nn.Module | None = None, unrolling_steps: int = 10):
+        super().__init__()
+        self.model = model
+        self.loss = loss if loss is not None else torch.nn.MSELoss()
+        self.unrolling_steps = unrolling_steps
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        c: torch.Tensor,
+        edge_index: torch.Tensor,
+        edge_attr: torch.Tensor,
+    ):
+        return self.model(x, c, edge_index, edge_attr)
+
+    def _compute_loss_train(self, x, x_prev, y):
+        return self.loss(x, y) + self.loss(x, x_prev)
+
+    def _compute_loss(self, x, y):
+        return self.loss(x, y)
+
+    def _preprocess_batch(self, batch: Batch):
+        return batch.x, batch.y, batch.c, batch.edge_index, batch.edge_attr
+
+    def _log_loss(self, loss, batch, stage: str):
+        self.log(
+            f"{stage}_loss",
+            loss,
+            on_step=False,
+            on_epoch=True,
+            prog_bar=True,
+            batch_size=int(batch.num_graphs),
+        )
+        return loss
+
+    def training_step(self, batch: Batch, _):
+        x, y, c, edge_index, edge_attr = self._preprocess_batch(batch)
+        for _ in range(self.unrolling_steps):
+            x_prev = x.detach()
+            x = self(x_prev, c, edge_index=edge_index, edge_attr=edge_attr)
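+        # x_prev is detached at each step, so gradients reach only the
+        # final model application (a truncated, one-step unrolling)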
+        loss = self._compute_loss_train(x, x_prev, y)
+        self._log_loss(loss, batch, "train")
+        return loss
+
+    def validation_step(self, batch: Batch, _):
+        x, y, c, edge_index, edge_attr = self._preprocess_batch(batch)
+        for _ in range(self.unrolling_steps):
+            x_prev = x.detach()
+            x = self(x_prev, c, edge_index=edge_index, edge_attr=edge_attr)
+            delta = self.loss(x, x_prev)
+            if delta < 1e-5:
+                break
+        loss = self._compute_loss(x, y)
+        self._log_loss(loss, batch, "val")
+        return loss
+
+    def test_step(self, batch: Batch, _):
+        x, y, c, edge_index, edge_attr = self._preprocess_batch(batch)
+        for _ in range(self.unrolling_steps):
+            x_prev = x.detach()
+            x = self(x_prev, c, edge_index=edge_index, edge_attr=edge_attr)
+        loss = self._compute_loss(x, y)
+        self._log_loss(loss, batch, "test")
+        return loss
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(self.parameters(), lr=5e-3)
+        return optimizer
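
For reference, a minimal wiring sketch of how the three pieces compose. The hidden size, layer count, and trainer settings below are illustrative assumptions, not part of this diff; GraphDataModule's constructor arguments are not shown in the hunks above, so they are left elided.

# sketch only -- hypothetical hyperparameters, see caveats above
from lightning import Trainer

from ThermalSolver.data_module import GraphDataModule
from ThermalSolver.model.local_gno import GatingGNO
from ThermalSolver.module import GraphSolver

datamodule = GraphDataModule(...)  # constructor args not shown in this diff
model = GatingGNO(
    x_ch_node=1,  # boundary values: [N, 1] after unsqueeze in _build_dataset
    f_ch_node=1,  # conductivity: [N, 1] likewise
    hidden=64,    # assumption
    layers=4,     # assumption
    edge_ch=3,    # edge_attr = (dx, dy, ||d||) built in _build_dataset
    out_ch=1,     # temperature
)
solver = GraphSolver(model=model, unrolling_steps=10)
Trainer(max_epochs=100).fit(solver, datamodule=datamodule)  # epochs: assumption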