import torch import torch.nn as nn from torch_geometric.nn import MessagePassing from torch.nn.utils import spectral_norm from torch_geometric.nn.conv import GCNConv, SAGEConv, GatedGraphConv, GraphConv # class GCNConvLayer(MessagePassing): # def __init__( # self, # in_channels, # out_channels, # aggr: str = 'mean', # bias: bool = True, # **kwargs, # ): # super().__init__(aggr=aggr, **kwargs) # self.in_channels = in_channels # self.out_channels = out_channels # if isinstance(in_channels, int): # in_channels = (in_channels, in_channels) # self.lin_rel = nn.Linear(in_channels[0], out_channels, bias=bias) # self.lin_root = nn.Linear(in_channels[1], out_channels, bias=False) # self.reset_parameters() # def reset_parameters(self): # super().reset_parameters() # self.lin_rel.reset_parameters() # self.lin_root.reset_parameters() # def forward(self, x, edge_index, # edge_weight = None, size = None): # edge_weight = self.normalize(edge_weight, edge_index, x.size(0), dtype=x.dtype) # out = self.propagate(edge_index, x=x, edge_weight=edge_weight, # size=size) # out = self.lin_rel(out) # out = out + self.lin_root(x) # return out # def message(self, x_j, edge_weight): # return x_j * edge_weight.view(-1, 1) # @staticmethod # def normalize(edge_weights, edge_index, num_nodes, dtype=None): # """Symmetrically normalize edge weights.""" # if dtype is None: # dtype = edge_weights.dtype # device = edge_index.device # row, col = edge_index # deg = torch.zeros(num_nodes, device=device, dtype=dtype) # deg = deg.scatter_add(0, row, edge_weights) # deg_inv_sqrt = deg.pow(-0.5) # deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0 # return deg_inv_sqrt[row] * edge_weights * deg_inv_sqrt[col] # class CorrectionNet(nn.Module): # def __init__(self, input_dim=1, output_dim=1, hidden_dim=8, n_layers=8): # super().__init__() # self.enc = nn.Linear(input_dim, hidden_dim, bias=True), # self.scale_x = nn.Parameter(torch.zeros(hidden_dim)) # self.scale_edge_attr = nn.Parameter(torch.zeros(1)) # 
class DiffusionLayer(MessagePassing):
    """One explicit diffusion step on a graph.

    Models ``T_new = T_old + dt * Divergence(Flux)``: each edge carries a flux
    proportional to the feature difference between its two endpoints, the
    fluxes are summed per node, and the node features are advanced by a
    learnable step size ``dt``.

    Args:
        channels: Width of the node feature vectors (input and output).
        **kwargs: Forwarded unchanged to ``MessagePassing.__init__``.
    """

    def __init__(
        self,
        channels: int,
        **kwargs,
    ):
        # Sum aggregation: the net flux at a node is the sum over its edges.
        super().__init__(aggr="add", **kwargs)

        # Learnable step size; starts small so x + dt * flux is close to x.
        self.dt = nn.Parameter(torch.tensor(1e-4))

        # Non-linear correction added on top of the linear flux. It is
        # bias-free, so a zero flux yields a zero correction.
        self.conductivity_net = nn.Sequential(
            nn.Linear(channels, channels, bias=False),
            nn.GELU(),
            nn.Linear(channels, channels, bias=False),
        )

        # Maps a scalar edge weight to a scalar conductance; the final
        # Softplus keeps the conductance non-negative.
        self.phys_encoder = nn.Sequential(
            nn.Linear(1, 8, bias=False),
            nn.Tanh(),
            nn.Linear(8, 1, bias=False),
            nn.Softplus(),
        )

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_weight: torch.Tensor,
    ) -> torch.Tensor:
        """Advance the node features by one diffusion step.

        Args:
            x: Node features whose last dimension equals ``channels``.
            edge_index: Graph connectivity, passed through to ``propagate``.
            edge_weight: One scalar per edge, shaped ``[E]`` or ``[E, 1]``.

        Returns:
            ``x + dt * net_flux``, with the same shape as ``x``.
        """
        # phys_encoder needs a trailing feature axis of size 1. Add it only
        # when it is missing: unsqueezing an [E, 1] tensor would produce
        # [E, 1, 1], which silently broadcasts against the [E, C] messages
        # instead of raising an error.
        already_column = edge_weight.dim() == 2 and edge_weight.size(-1) == 1
        if not already_column:
            edge_weight = edge_weight.unsqueeze(-1)

        conductance = self.phys_encoder(edge_weight)
        net_flux = self.propagate(edge_index, x=x, conductance=conductance)
        return x + (net_flux * self.dt)

    def message(
        self,
        x_i: torch.Tensor,
        x_j: torch.Tensor,
        conductance: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the flux carried by each edge.

        Args:
            x_i: Features of the node that receives the message.
            x_j: Features of the neighbouring node that sends it.
            conductance: Per-edge conductance from ``phys_encoder``.

        Returns:
            The linear flux ``(x_j - x_i) * conductance`` plus the learned
            correction ``conductivity_net(flux)``.
        """
        delta = x_j - x_i
        flux = delta * conductance
        flux = flux + self.conductivity_net(flux)
        return flux


class CorrectionNet(nn.Module):
    """Encoder / diffusion stack / decoder network operating on a graph.

    Args:
        input_dim: Number of input features per node.
        output_dim: Number of output features per node.
        hidden_dim: Width of the hidden node representation.
        n_layers: Number of stacked ``DiffusionLayer`` steps.
    """

    def __init__(self, input_dim=1, output_dim=1, hidden_dim=32, n_layers=4):
        super().__init__()

        # Encoder: projects the node input into the hidden feature space.
        self.enc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim, bias=True),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim, bias=True),
            nn.GELU(),
        )

        # Log-scale factors applied as exp(.) in forward(); the zero
        # initialisation makes both scalings start as the identity.
        self.scale_x = nn.Parameter(torch.zeros(hidden_dim))
        self.scale_edge_attr = nn.Parameter(torch.zeros(1))

        # Stack of diffusion layers.
        self.layers = torch.nn.ModuleList(
            [DiffusionLayer(hidden_dim) for _ in range(n_layers)]
        )

        # Decoder: projects hidden features to the output space. The final
        # Softplus keeps the output non-negative.
        self.dec = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim, bias=True),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim, bias=True),
            nn.Softplus(),
        )

        # Activation applied after every diffusion layer.
        self.func = torch.nn.GELU()

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
    ) -> torch.Tensor:
        """Run the encoder, the diffusion stack and the decoder.

        Args:
            x: Node inputs whose last dimension equals ``input_dim``.
            edge_index: Graph connectivity forwarded to every layer.
            edge_attr: One scalar per edge, shaped ``[E]`` or ``[E, 1]``.

        Returns:
            The decoder output, with last dimension ``output_dim``. The
            input ``x`` is NOT added back to it.
        """
        # 1. Encode, then rescale each hidden channel.
        h = self.enc(x) * torch.exp(self.scale_x)

        # 2. Rescale the edge attributes with a single learnable factor.
        w = edge_attr * torch.exp(self.scale_edge_attr)

        # 3. Message passing: each DiffusionLayer already contains its own
        #    residual update, so only the activation is applied here.
        for layer in self.layers:
            h = layer(h, edge_index, w)
            h = self.func(h)

        # 4. Decode. The result is returned as-is; no global residual
        #    (x + output) is applied in this method.
        return self.dec(h)