import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch.nn.utils import spectral_norm


class DiffusionLayer(MessagePassing):
    """One learned diffusion step on a graph.

    Models ``x_new = x_old + alpha * sum_j flux_ij``. Here
    ``flux_ij = f(x_j - x_i, w_ij)`` is the linear flux ``(x_j - x_i) * g(w_ij)``
    plus a learned residual correction of that flux. ``g`` is a learned
    non-negative (Softplus) encoding of the scalar edge weight. Messages are
    summed at each target node (``aggr="add"``).

    Args:
        channels: Feature width of ``x``; the flux-refinement MLP maps
            ``channels -> channels``.
        **kwargs: Forwarded unchanged to ``MessagePassing.__init__``.
    """

    def __init__(
        self,
        channels: int,
        **kwargs,
    ):
        super().__init__(aggr="add", **kwargs)
        # Residual refinement applied to the linear flux. spectral_norm keeps
        # each linear map's largest singular value near 1.
        self.conductivity_net = nn.Sequential(
            spectral_norm(nn.Linear(channels, channels, bias=False)),
            nn.GELU(),
            spectral_norm(nn.Linear(channels, channels, bias=False)),
        )
        # Maps a scalar edge weight to a conductance; the final Softplus makes
        # it non-negative.
        self.phys_encoder = nn.Sequential(
            spectral_norm(nn.Linear(1, 8, bias=True)),
            nn.Tanh(),
            spectral_norm(nn.Linear(8, 1, bias=True)),
            nn.Softplus(),
        )
        # Raw step size; read through the clamped `alpha` property.
        self.alpha_param = nn.Parameter(torch.tensor(1e-2))

    @property
    def alpha(self) -> torch.Tensor:
        """Step size: ``alpha_param`` clamped to ``[1e-7, 1.0]``."""
        return torch.clamp(self.alpha_param, min=1e-7, max=1.0)

    def forward(self, x, edge_index, edge_weight, conductivity):
        """Apply one diffusion step.

        Args:
            x: Node features; assumed ``(N, channels)`` (TODO confirm).
            edge_index: Graph connectivity passed to ``propagate``.
            edge_weight: One scalar per edge, shape ``(E,)`` or ``(E, 1)``.
            conductivity: Unused by this layer; accepted so callers can pass
                it through uniformly.

        Returns:
            ``x + alpha * aggregated_flux``, same shape as ``x``.
        """
        # The encoder expects a trailing feature axis of size 1. Unsqueezing
        # an already-(E, 1) tensor would yield (E, 1, 1) conductances that
        # broadcast against the (E, C) differences into an (E, E, C) tensor,
        # so only add the axis when it is missing.
        if edge_weight.dim() == 1:
            edge_weight = edge_weight.unsqueeze(-1)
        conductance = self.phys_encoder(edge_weight)
        net_flux = self.propagate(edge_index, x=x, conductance=conductance)
        return x + self.alpha * net_flux

    def message(self, x_i, x_j, conductance):
        """Compute the per-edge flux.

        ``x_i`` / ``x_j`` are the features gathered for each edge's target /
        source node (PyG ``_i`` / ``_j`` convention).
        """
        delta = x_j - x_i
        flux = delta * conductance
        # Learned residual correction on top of the linear flux.
        flux = flux + self.conductivity_net(flux)
        return flux


class DiffusionNet(nn.Module):
    """Encoder -> stack of ``DiffusionLayer`` steps -> decoder.

    Args:
        input_dim: Width of the input node features.
        output_dim: Width of the decoded output.
        hidden_dim: Width of the latent features diffused by the layers.
        n_layers: Number of diffusion steps.
        shared_weights: If True, one ``DiffusionLayer`` instance is applied
            ``n_layers`` times; otherwise each step has its own layer.
    """

    def __init__(
        self,
        input_dim=1,
        output_dim=1,
        hidden_dim=8,
        n_layers=4,
        shared_weights=False,
    ):
        super().__init__()
        # Encoder: projects input features to the hidden space.
        self.enc = nn.Sequential(
            spectral_norm(nn.Linear(input_dim, hidden_dim, bias=True)),
            nn.GELU(),
            spectral_norm(nn.Linear(hidden_dim, hidden_dim, bias=True)),
        )
        # Log-scale gains (initialized to 0 => gain exp(0) = 1).
        self.scale_x = nn.Parameter(torch.zeros(hidden_dim))
        self.scale_edge_attr = nn.Parameter(torch.zeros(1))

        if shared_weights:
            # The same module object appears n_layers times in the list.
            diffusion_layer = DiffusionLayer(hidden_dim)
            self.layers = torch.nn.ModuleList(
                [diffusion_layer for _ in range(n_layers)]
            )
        else:
            # Independent parameters per diffusion step.
            self.layers = torch.nn.ModuleList(
                [DiffusionLayer(hidden_dim) for _ in range(n_layers)]
            )

        # Decoder: projects hidden features back to output space. The final
        # Softplus makes every decoded value non-negative.
        self.dec = nn.Sequential(
            spectral_norm(nn.Linear(hidden_dim, hidden_dim, bias=True)),
            nn.GELU(),
            spectral_norm(nn.Linear(hidden_dim, output_dim, bias=True)),
            nn.Softplus(),
        )
        self.func = torch.nn.GELU()
        # Raw dt; read through the clamped `dt` property.
        self.dt_param = nn.Parameter(torch.tensor(1e-2))

    @property
    def dt(self) -> torch.Tensor:
        """Step size: ``dt_param`` clamped to ``[1e-5, 0.5]``."""
        return torch.clamp(self.dt_param, min=1e-5, max=0.5)

    def forward(self, x, edge_index, edge_attr, conductivity):
        """Run encode -> diffuse -> decode.

        Args:
            x: Input node features (width ``input_dim``).
            edge_index: Graph connectivity.
            edge_attr: One scalar per edge, shape ``(E,)`` or ``(E, 1)``.
            conductivity: Forwarded to each layer, which currently ignores it.

        Returns:
            ``dec(h) + dt * x``.
        """
        x_input = x

        # 1. Encode, with a learnable positive per-channel gain.
        h = self.enc(x) * torch.exp(self.scale_x)

        # 2. Learnable positive gain on the edge attributes.
        w = edge_attr * torch.exp(self.scale_edge_attr)

        # 3. Diffusion steps; each layer already adds its own residual.
        for layer in self.layers:
            h = layer(h, edge_index, w, conductivity)
            h = self.func(h)

        # 4. Decode (Softplus => delta_x is non-negative).
        delta_x = self.dec(h)

        # 5. Output.
        # NOTE(review): earlier comments described a residual Euler update
        # `x_input + dt * delta_x`, but this returns `delta_x + dt * x_input`.
        # Because the decoder ends in Softplus, switching to the residual form
        # would also restrict updates to be non-negative (no cooling).
        # Behaviour kept as-is; confirm the intended formulation.
        return delta_x + x_input * self.dt