Files
thermal-conduction-ml/ThermalSolver/model/learnable_finite_difference.py
2025-11-18 21:55:54 +01:00

54 lines
1.6 KiB
Python

import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
# from torch.nn.utils import spectral_norm
class GCNConvLayer(MessagePassing):
    """Graph convolution with degree-normalised scalar edge weights.

    Node features are first transformed by a small two-layer MLP
    (``self.lin``, in_channels -> out_channels) and then summed over incoming
    edges. Each neighbour contribution is scaled by
    ``w_ij / (1 + sqrt(deg_src * deg_dst))``.

    Args:
        in_channels: Size of the input node-feature dimension.
        out_channels: Size of the output node-feature dimension.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr="add")
        # Per-node MLP applied before message passing. PyG's GCNConv likewise
        # transforms features first, then propagates. The original code built
        # this module but never called it, so out_channels was ignored.
        self.lin = nn.Sequential(
            nn.Linear(in_channels, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels),
            nn.ReLU(),
        )

    def _compute_edge_weight(self, edge_index, edge_w, deg):
        """Return one normalised weight per edge.

        Args:
            edge_index: Connectivity tensor of shape (2, num_edges).
            edge_w: Raw edge weights, one per edge; shape (num_edges,) or
                (num_edges, 1).
            deg: Per-node degree tensor, indexed by node id.

        Returns:
            1-D tensor of length num_edges equal to
            ``edge_w / (1 + sqrt(deg[edge_index[0]] * deg[edge_index[1]]))``.
            For non-negative degrees, the ``1 +`` keeps the denominator
            >= 1, so zero-degree endpoints never cause a division by zero.
        """
        # reshape(-1) always yields a flat per-edge vector. squeeze() turns a
        # single-edge (1, 1) weight into a 0-d scalar, and it leaves a
        # multi-column tensor 2-D so it broadcasts silently instead of
        # failing loudly.
        flat_w = edge_w.reshape(-1)
        norm = 1 + torch.sqrt(deg[edge_index[0]] * deg[edge_index[1]])
        return flat_w / norm

    def forward(self, x, edge_index, edge_attr, deg):
        """Transform node features and aggregate weighted neighbour messages.

        Args:
            x: Node features, expected shape (num_nodes, in_channels).
            edge_index: Connectivity tensor of shape (2, num_edges).
            edge_attr: Raw edge weights, one per edge.
            deg: Per-node degree tensor used for weight normalisation.

        Returns:
            Aggregated node features whose last dimension is out_channels.
        """
        edge_w = self._compute_edge_weight(edge_index, edge_attr, deg)
        h = self.lin(x)
        return self.propagate(edge_index, x=h, edge_weight=edge_w)

    def message(self, x_j, edge_weight):
        """Scale each neighbour's (transformed) features by its edge weight.

        ``edge_weight`` is reshaped to (num_edges, 1) so it broadcasts across
        the feature dimension of ``x_j``.
        """
        return edge_weight.view(-1, 1) * x_j
class CorrectionNet(nn.Module):
    """Encoder -> single GCNConvLayer -> decoder, one scalar in and out per node.

    The encoder lifts a 1-feature input to ``hidden_dim`` features through a
    ``hidden_dim // 2`` intermediate layer. A ``GCNConvLayer`` then mixes
    neighbour information, and the decoder maps back down to 1 feature.

    Args:
        hidden_dim: Width of the latent node features. The encoder/decoder
            bottleneck has ``hidden_dim // 2`` units, so this must be >= 2.

    Raises:
        ValueError: If ``hidden_dim < 2``. That would create zero-width
            bottleneck layers that discard the input entirely.
    """

    def __init__(self, hidden_dim=8):
        super().__init__()
        if hidden_dim < 2:
            raise ValueError(f"hidden_dim must be >= 2, got {hidden_dim}")
        half = hidden_dim // 2
        # Submodules are registered in the original order (enc, model, dec),
        # with the same Sequential indices, so existing checkpoints still load.
        self.enc = nn.Sequential(
            nn.Linear(1, half),
            nn.ReLU(),
            nn.Linear(half, hidden_dim),
            nn.ReLU(),
        )
        self.model = GCNConvLayer(hidden_dim, hidden_dim)
        self.dec = nn.Sequential(
            nn.Linear(hidden_dim, half),
            nn.ReLU(),
            nn.Linear(half, 1),
            # NOTE(review): this final ReLU clamps the output to >= 0. Confirm
            # that a non-negative correction is intended, not a signed one.
            nn.ReLU(),
        )

    def forward(self, x, edge_index, edge_attr, deg):
        """Apply encoder, graph convolution and decoder in sequence.

        Args:
            x: Per-node input with a trailing feature dimension of size 1.
            edge_index: Connectivity tensor of shape (2, num_edges).
            edge_attr: Raw edge weights, one per edge.
            deg: Per-node degree tensor forwarded to the graph layer.

        Returns:
            Per-node output with a trailing feature dimension of size 1.
        """
        h = self.enc(x)
        h = self.model(h, edge_index, edge_attr, deg)
        return self.dec(h)