import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing


class DiffusionLayer(MessagePassing):
    """One learnable explicit-Euler diffusion step on a graph.

    Models ``T_new = T_old + dt * div(flux)``.

    The flux along each edge is the feature difference between its endpoints,
    scaled by a positive conductance computed from the edge weight. A learned
    per-channel correction is then added to it.

    Args:
        channels: Number of node feature channels.
        **kwargs: Extra keyword arguments forwarded to
            ``MessagePassing.__init__``. Messages are always sum-aggregated
            (``aggr='add'``).
    """

    def __init__(
        self,
        channels: int,
        **kwargs,
    ):
        super().__init__(aggr='add', **kwargs)
        # Learnable time step.
        # NOTE(review): nothing constrains dt to stay positive; a negative
        # value would make the update anti-diffusive. Consider a
        # softplus/exp parameterisation if that is observed.
        self.dt = nn.Parameter(torch.tensor(1e-4))
        # Per-channel nonlinear correction added to each edge flux.
        self.conductivity_net = nn.Sequential(
            nn.Linear(channels, channels, bias=False),
            nn.GELU(),
            nn.Linear(channels, channels, bias=False),
        )
        # Maps a scalar edge weight to a scalar conductance; the final
        # Softplus keeps the conductance strictly positive.
        self.phys_encoder = nn.Sequential(
            nn.Linear(1, 8, bias=False),
            nn.Tanh(),
            nn.Linear(8, 1, bias=False),
            nn.Softplus(),
        )

    def forward(self, x, edge_index, edge_weight):
        """Apply one diffusion step.

        Args:
            x: Node features, presumably shaped (num_nodes, channels).
            edge_index: Graph connectivity in PyG ``(2, num_edges)`` format.
            edge_weight: One scalar per edge, shaped ``(num_edges,)`` or
                ``(num_edges, 1)``.

        Returns:
            ``x + dt * net_flux``, with the same shape as ``x``.
        """
        # Normalise to (E, 1) for the Linear(1, ...) encoder. Unsqueezing an
        # already (E, 1) tensor would yield (E, 1, 1) and silently broadcast
        # the messages into an (E, E, C) tensor.
        if edge_weight.dim() == 1:
            edge_weight = edge_weight.unsqueeze(-1)
        conductance = self.phys_encoder(edge_weight)
        net_flux = self.propagate(edge_index, x=x, conductance=conductance)
        return x + (net_flux * self.dt)

    def message(self, x_i, x_j, conductance):
        """Flux from source node j into target node i along each edge."""
        delta = x_j - x_i
        flux = delta * conductance
        # Residual nonlinear correction on top of the linear flux.
        return flux + self.conductivity_net(flux)


class DiffusionNet(nn.Module):
    """Encoder -> stacked DiffusionLayers -> decoder network on graph nodes.

    Args:
        input_dim: Number of input features per node.
        output_dim: Number of output features per node.
        hidden_dim: Width of the latent feature space.
        n_layers: Number of DiffusionLayer steps applied in sequence.
    """

    def __init__(self, input_dim=1, output_dim=1, hidden_dim=8, n_layers=4):
        super().__init__()
        # Encoder: projects input temperature to the hidden feature space.
        self.enc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim, bias=True),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim, bias=True),
            nn.GELU(),
        )
        # Log-scale conditioning gains, applied via exp(); zero-initialised
        # so both gains start at exactly 1.
        self.scale_x = nn.Parameter(torch.zeros(hidden_dim))
        self.scale_edge_attr = nn.Parameter(torch.zeros(1))
        # Stack of diffusion layers.
        self.layers = torch.nn.ModuleList(
            [DiffusionLayer(hidden_dim) for _ in range(n_layers)]
        )
        # Decoder: projects hidden features back to temperature space.
        self.dec = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim, bias=True),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim, bias=True),
            nn.Softplus(),  # Ensure positive temperature output
        )
        self.func = torch.nn.GELU()

    def forward(self, x, edge_index, edge_attr):
        """Predict node temperatures.

        Args:
            x: Node input features; the last dimension must equal
                ``input_dim``.
            edge_index: Graph connectivity in PyG ``(2, num_edges)`` format.
            edge_attr: One scalar per edge, shaped ``(num_edges,)`` or
                ``(num_edges, 1)``.

        Returns:
            The decoder output, non-negative because of the final Softplus.
            This is the predicted temperature itself, not a correction
            added to ``x``.
        """
        # 1. Encode and apply the learnable feature gain.
        h = self.enc(x) * torch.exp(self.scale_x)

        # 2. Scale edge attributes (learnable gating of physical conductivity).
        w = edge_attr * torch.exp(self.scale_edge_attr)

        # 3. Message passing (diffusion steps). Each DiffusionLayer already
        #    adds its own residual (h + dt * flux).
        for layer in self.layers:
            h = self.func(layer(h, edge_index, w))

        # 4. Decode. The original returned the undefined name `delta_ddx`
        #    (NameError); the global residual `x + delta` had been disabled,
        #    so the decoder output is returned directly.
        return self.dec(h)