import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

from torch_geometric.nn import MessagePassing


class LogPhysEncoder(nn.Module):
    """
    Processes 1/dx in log-space to handle multiple scales of geometry
    (from micro-meshes to macro-meshes) without numerical instability.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            spectral_norm(nn.Linear(1, hidden_dim)),
            nn.GELU(),
            spectral_norm(nn.Linear(hidden_dim, 1)),
            nn.Softplus(),  # Physical conductance must be positive
        )

    def forward(self, inv_dx):
        # We use log(1/dx) to linearize the scale of different geometries
        log_inv_dx = torch.log(inv_dx + 1e-9)
        return self.mlp(log_inv_dx)
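

# Hedged usage sketch (illustration only, not part of the original model
# code): how LogPhysEncoder maps inverse edge lengths spanning several
# orders of magnitude to positive conductances. Sizes and values below
# are assumptions.
def _demo_log_phys_encoder():
    enc = LogPhysEncoder(hidden_dim=8)
    # Inverse spacings from a macro-mesh (dx = 1e3) down to a micro-mesh
    # (dx = 1e-3); shape [num_edges, 1], as the first Linear expects
    inv_dx = torch.tensor([[1e-3], [1.0], [1e3]])
    return enc(inv_dx)  # shape [3, 1]; Softplus keeps every value positive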


class DiffusionLayer(MessagePassing):
    """
    Models: T_new = T_old + dt * Divergence(Flux)
    """

    def __init__(
        self,
        channels: int,
        **kwargs,
    ):
        super().__init__(aggr="add", **kwargs)
        self.conductivity_net = nn.Sequential(
            spectral_norm(nn.Linear(channels, channels, bias=False)),
            nn.GELU(),
            spectral_norm(nn.Linear(channels, channels, bias=False)),
        )
        self.phys_encoder = LogPhysEncoder(hidden_dim=channels)
        self.alpha_param = nn.Parameter(torch.tensor(1e-2))

    @property
    def alpha(self):
        # Clamp the learnable step size to keep the explicit update stable
        return torch.clamp(self.alpha_param, min=1e-7, max=1.0)

    def forward(self, x, edge_index, edge_weight, conductivity):
        # NOTE: `conductivity` is accepted for API compatibility but unused;
        # the conductance is derived from the edge weights (1/dx) instead.
        edge_weight = edge_weight.unsqueeze(-1)
        conductance = self.phys_encoder(edge_weight)
        net_flux = self.propagate(edge_index, x=x, conductance=conductance)
        return x + self.alpha * net_flux

    def message(self, x_i, x_j, conductance):
        # Fourier-style flux: proportional to the feature difference across
        # the edge, scaled by the learned conductance, plus a learned
        # nonlinear correction
        delta = x_j - x_i
        flux = delta * conductance
        flux = flux + self.conductivity_net(flux)
        return flux
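

# Hedged usage sketch (illustration only, not part of the original model
# code): one diffusion step on a 3-node path graph. The toy graph and the
# `None` passed for the unused `conductivity` argument are assumptions.
def _demo_diffusion_layer():
    layer = DiffusionLayer(channels=4)
    x = torch.randn(3, 4)  # node features, shape [num_nodes, channels]
    # Bidirectional edges 0<->1 and 1<->2, so features flow both ways
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    edge_weight = torch.ones(4)  # 1/dx per edge, shape [num_edges]
    return layer(x, edge_index, edge_weight, conductivity=None)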


class DiffusionNet(nn.Module):
    def __init__(
        self,
        input_dim=1,
        output_dim=1,
        hidden_dim=8,
        n_layers=4,
        shared_weights=False,
    ):
        super().__init__()
        # Encoder: projects input temperature into the hidden feature space
        self.enc = nn.Sequential(
            spectral_norm(nn.Linear(input_dim, hidden_dim, bias=True)),
            nn.GELU(),
            spectral_norm(nn.Linear(hidden_dim, hidden_dim, bias=True)),
        )
        # Learnable scale parameters for conditioning
        self.scale_x = nn.Parameter(torch.zeros(hidden_dim))
        self.scale_edge_attr = nn.Parameter(torch.zeros(1))
        if shared_weights:
            # Reuse one DiffusionLayer at every step (weight tying)
            diffusion_layer = DiffusionLayer(hidden_dim)
            self.layers = nn.ModuleList(
                [diffusion_layer for _ in range(n_layers)]
            )
        else:
            # Stack of independent DiffusionLayers
            self.layers = nn.ModuleList(
                [DiffusionLayer(hidden_dim) for _ in range(n_layers)]
            )
        # Decoder: projects hidden features back to temperature space
        self.dec = nn.Sequential(
            spectral_norm(nn.Linear(hidden_dim, hidden_dim, bias=True)),
            nn.GELU(),
            spectral_norm(nn.Linear(hidden_dim, output_dim, bias=True)),
            nn.Softplus(),  # Ensure positive temperature output
        )
        self.func = nn.GELU()
        self.dt_param = nn.Parameter(torch.tensor(1e-2))

    @property
    def dt(self):
        # Clamp the learnable time step to keep the explicit update stable
        return torch.clamp(self.dt_param, min=1e-5, max=0.5)

    def forward(self, x, edge_index, edge_attr, conductivity):
        # 1. Global residual connection setup: save the input so the network
        #    learns the correction (Delta T), not the absolute T
        x_input = x
        # 2. Encode
        h = self.enc(x) * torch.exp(self.scale_x)
        # 3. Scale edge attributes (learnable gating of physical conductivity)
        w = edge_attr * torch.exp(self.scale_edge_attr)
        # 4. Message passing (diffusion steps); each DiffusionLayer applies
        #    its own internal residual update
        for layer in self.layers:
            h = layer(h, edge_index, w, conductivity)
            h = self.func(h)
        # 5. Decode
        delta_x = self.dec(h)
        # 6. Final update (explicit Euler step): T_new = T_old + dt * Delta T
        return x_input + self.dt * delta_x
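

# Hedged end-to-end sketch (illustration only, not part of the original
# model code): a single forward pass on a toy 4-node chain. The graph,
# shapes, and the `None` conductivity placeholder are assumptions.
if __name__ == "__main__":
    model = DiffusionNet(input_dim=1, output_dim=1, hidden_dim=8, n_layers=4)
    temperature = torch.rand(4, 1)  # initial nodal temperatures, [N, 1]
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])
    edge_attr = torch.ones(6)  # 1/dx per edge, shape [num_edges]
    t_next = model(temperature, edge_index, edge_attr, conductivity=None)
    print(t_next.shape)  # torch.Size([4, 1])
    # Set shared_weights=True to tie all n_layers diffusion steps to one
    # shared DiffusionLayer (fewer parameters, recurrent-style dynamics)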