import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
|
|
class DiffusionLayer(MessagePassing):
    """One explicit-Euler diffusion step on a graph.

    Models ``T_new = T_old + dt * Divergence(Flux)``.

    - The message from neighbour ``j`` to node ``i`` is the flux
      ``(x_j - x_i) * conductance``, refined by a learned residual MLP.
    - The incoming messages of each node are summed (``aggr='add'``).
    - The sum is scaled by a learnable time step ``dt`` and added back to ``x``.

    Args:
        channels: Number of node feature channels. The flux-correction MLP
            maps ``channels -> channels``, so the last dimension of ``x``
            must equal ``channels``.
        **kwargs: Forwarded unchanged to ``MessagePassing.__init__``.
    """

    def __init__(
        self,
        channels: int,
        **kwargs,
    ):
        super().__init__(aggr='add', **kwargs)

        # Learnable time step, initialised very small so each layer starts
        # close to the identity map.
        # NOTE(review): dt is unconstrained and could become negative during
        # training (an anti-diffusion step). Confirm this is acceptable.
        self.dt = nn.Parameter(torch.tensor(1e-4))

        # Learned residual correction applied to every edge flux (no biases).
        self.conductivity_net = nn.Sequential(
            nn.Linear(channels, channels, bias=False),
            nn.GELU(),
            nn.Linear(channels, channels, bias=False),
        )

        # Maps a scalar edge weight to a conductance.
        # The final Softplus keeps the conductance strictly positive.
        self.phys_encoder = nn.Sequential(
            nn.Linear(1, 8, bias=False),
            nn.Tanh(),
            nn.Linear(8, 1, bias=False),
            nn.Softplus(),
        )

    def forward(self, x, edge_index, edge_weight):
        """Apply one diffusion step.

        Args:
            x: Node features; the last dimension must be ``channels``.
            edge_index: Graph connectivity, passed to ``propagate``.
            edge_weight: One scalar weight per edge, with shape ``(E,)``
                or ``(E, 1)``.

        Returns:
            ``x + dt * net_flux``, with the same shape as ``x``.

        Raises:
            ValueError: If ``edge_weight`` is neither ``(E,)`` nor ``(E, 1)``.
        """
        # phys_encoder consumes exactly one feature per edge, so normalise the
        # input to (E, 1). Unsqueezing an already (E, 1) tensor would yield
        # (E, 1, 1) and make the flux broadcast to (E, E, C).
        if edge_weight.dim() == 1:
            edge_weight = edge_weight.unsqueeze(-1)
        elif edge_weight.dim() != 2 or edge_weight.size(-1) != 1:
            raise ValueError(
                "edge_weight must have shape (E,) or (E, 1), got "
                f"{tuple(edge_weight.shape)}"
            )

        conductance = self.phys_encoder(edge_weight)  # (E, 1)

        net_flux = self.propagate(edge_index, x=x, conductance=conductance)

        return x + net_flux * self.dt

    def message(self, x_i, x_j, conductance):
        """Compute the flux carried by each edge toward its receiving node."""
        # Fickian-style flux: proportional to the feature difference.
        flux = (x_j - x_i) * conductance
        # NOTE(review): GELU is not an odd function, so the corrected flux
        # j->i is generally not the negation of i->j. The summed "heat" is
        # therefore not exactly conserved. Confirm this is intended.
        return flux + self.conductivity_net(flux)
|
|
class DiffusionNet(nn.Module):
    """Encode -> diffuse -> decode network predicting node temperatures.

    Args:
        input_dim: Number of input features per node (encoder input width).
        output_dim: Number of output features per node (decoder output width).
        hidden_dim: Width of the latent space the diffusion layers operate in.
        n_layers: Number of stacked ``DiffusionLayer`` steps.
    """

    def __init__(self, input_dim=1, output_dim=1, hidden_dim=8, n_layers=4):
        super().__init__()

        # Encoder: projects the input temperature into the hidden feature space.
        self.enc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim, bias=True),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim, bias=True),
            nn.GELU(),
        )

        # Learnable log-scales used for conditioning. They are applied through
        # exp(), so the zero initialisation means an initial scale of 1.
        self.scale_x = nn.Parameter(torch.zeros(hidden_dim))
        self.scale_edge_attr = nn.Parameter(torch.zeros(1))

        # Stack of diffusion layers; each performs one residual diffusion step.
        self.layers = torch.nn.ModuleList(
            [DiffusionLayer(hidden_dim) for _ in range(n_layers)]
        )

        # Decoder: projects hidden features back to temperature space.
        self.dec = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim, bias=True),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim, bias=True),
            nn.Softplus(),  # Ensure positive temperature output
        )

        # Non-linearity applied after every diffusion step.
        self.func = torch.nn.GELU()

    def forward(self, x, edge_index, edge_attr):
        """Predict node temperatures.

        Args:
            x: Node features; the last dimension must be ``input_dim``.
            edge_index: Graph connectivity, forwarded to every diffusion layer.
            edge_attr: Per-edge weights (physical conductivity). They are
                scaled and passed to each layer as ``edge_weight``.

        Returns:
            The decoder output: strictly positive (Softplus) and with last
            dimension ``output_dim``.
        """
        # 1. Encode and apply the learnable per-channel scale.
        h = self.enc(x) * torch.exp(self.scale_x)

        # 2. Scale edge attributes (learnable gating of physical conductivity).
        w = edge_attr * torch.exp(self.scale_edge_attr)

        # 3. Message passing (diffusion steps). Each layer updates h through
        #    its own internal residual connection.
        for layer in self.layers:
            h = self.func(layer(h, edge_index, w))

        # 4. Decode. The decoder output is the prediction itself.
        # The global residual ``x + dec(h)`` is intentionally not used: with a
        # Softplus head the correction is always positive, so it could only
        # ever increase the input temperature.
        t_pred = self.dec(h)
        return t_pred