fix model
@@ -44,7 +44,7 @@ from torch_geometric.nn.conv import GCNConv, SAGEConv, GatedGraphConv, GraphConv

#     def message(self, x_j, edge_weight):
#         return x_j * edge_weight.view(-1, 1)


#     @staticmethod
#     def normalize(edge_weights, edge_index, num_nodes, dtype=None):
#         """Symmetrically normalize edge weights."""
@@ -58,7 +58,7 @@ from torch_geometric.nn.conv import GCNConv, SAGEConv, GatedGraphConv, GraphConv
#         deg_inv_sqrt = deg.pow(-0.5)
#         deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0
#         return deg_inv_sqrt[row] * edge_weights * deg_inv_sqrt[col]


# class CorrectionNet(nn.Module):
#     def __init__(self, input_dim=1, output_dim=1, hidden_dim=8, n_layers=8):
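The body of the commented-out normalize() is split across the two hunks above, and the lines that set up row, col, and deg are not shown. For reference, a self-contained sketch of the standard GCN-style symmetric normalization it appears to implement; the degree computation is an assumption, not the file's actual code:

import torch

def normalize(edge_weights, edge_index, num_nodes, dtype=None):
    """Symmetrically normalize edge weights: w_ij / sqrt(deg_i * deg_j)."""
    row, col = edge_index[0], edge_index[1]
    # Assumed setup: weighted degree accumulated from the edge weights.
    deg = torch.zeros(num_nodes, dtype=dtype).scatter_add_(0, col, edge_weights)
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0
    return deg_inv_sqrt[row] * edge_weights * deg_inv_sqrt[col]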
@@ -89,7 +89,7 @@ from torch_geometric.nn.conv import GCNConv, SAGEConv, GatedGraphConv, GraphConv
#         super().__init__()
#         layers = []
#         func = torch.nn.ReLU


#         self.network = nn.Sequential(
#             nn.Linear(input_dim, hidden_dim),
#             func(),
@@ -112,30 +112,32 @@ from torch_geometric.nn.conv import GCNConv, SAGEConv, GatedGraphConv, GraphConv
# import torch.nn as nn
# from torch_geometric.nn import MessagePassing


class DiffusionLayer(MessagePassing):
    """
    Models: T_new = T_old + dt * Divergence(Flux)
    """

    def __init__(
        self,
        channels: int,
        **kwargs,
    ):
-        super().__init__(aggr='add', **kwargs)
+        super().__init__(aggr="add", **kwargs)

        self.dt = nn.Parameter(torch.tensor(1e-4))
        self.conductivity_net = nn.Sequential(
            nn.Linear(channels, channels, bias=False),
            nn.GELU(),
            nn.Linear(channels, channels, bias=False),
        )

        self.phys_encoder = nn.Sequential(
            nn.Linear(1, 8, bias=False),
            nn.Tanh(),
            nn.Linear(8, 1, bias=False),
-            nn.Softplus()
+            nn.Softplus(),
        )

    def forward(self, x, edge_index, edge_weight):
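The forward body is truncated by the next hunk boundary. As an illustration of the update rule named in the docstring (T_new = T_old + dt * Divergence(Flux)), here is one hypothetical way the truncated methods could use the submodules declared above; the flux form and both method bodies are assumptions, not the repository's code:

    def forward(self, x, edge_index, edge_weight):
        h = self.conductivity_net(x)  # feature-space potential driving the flux
        div = self.propagate(edge_index, x=h, edge_weight=edge_weight)
        return x + self.dt * div  # explicit Euler step: T_new = T_old + dt * div

    def message(self, x_i, x_j, edge_weight):
        # Flux along each edge: potential difference times a positive,
        # learned conductivity derived from the physical edge weight.
        k = self.phys_encoder(edge_weight.view(-1, 1))  # (E, 1), positive via Softplus
        return k * (x_j - x_i)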
@@ -154,7 +156,7 @@ class DiffusionLayer(MessagePassing):
class CorrectionNet(nn.Module):
    def __init__(self, input_dim=1, output_dim=1, hidden_dim=32, n_layers=4):
        super().__init__()

        # Encoder: Projects input temperature to hidden feature space
        self.enc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim, bias=True),
@@ -167,12 +169,12 @@ class CorrectionNet(nn.Module):

        # Scale parameters for conditioning
        self.scale_edge_attr = nn.Parameter(torch.zeros(1))

        # Stack of Diffusion Layers
        self.layers = torch.nn.ModuleList(
            [DiffusionLayer(hidden_dim) for _ in range(n_layers)]
        )

        # Decoder: Projects hidden features back to Temperature space
        self.dec = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim, bias=True),
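A note on the zero-initialized scale parameters: passing them through torch.exp (as forward() does below) makes the gate start at exp(0) = 1, so the physical edge attributes initially pass through unchanged, and the multiplier stays strictly positive while being learned in log space. A two-line check:

import torch
print(torch.exp(torch.zeros(1)))        # tensor([1.]) -> identity gate at init
print(torch.exp(torch.tensor([-0.7])))  # ~tensor([0.4966]) -> learned damping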
@@ -185,26 +187,26 @@ class CorrectionNet(nn.Module):

    def forward(self, x, edge_index, edge_attr):
        # 1. Global Residual Connection setup
        # We save the input to add it back at the very end.
        # The network learns the correction (Delta T), not the absolute T.
        x_input = x

        # 2. Encode
        h = self.enc(x) * torch.exp(self.scale_x)

        # 3. Scale edge attributes (learnable gating of physical conductivity)
        w = edge_attr * torch.exp(self.scale_edge_attr)

        # 4. Message Passing (Diffusion Steps)
        for layer in self.layers:
            # h is updated internally via residual connection in DiffusionLayer
            h = layer(h, edge_index, w)
            h = self.func(h)

        # 5. Decode
        delta_x = self.dec(h)

        # 6. Final Update (Explicit Euler Step)
        # T_new = T_old + Correction
        # return x_input + delta_x
        return delta_x
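A minimal usage sketch (not from the repository): since forward() returns only the correction, the caller applies the residual update. The toy graph, shapes, and values are illustrative, and the snippet assumes the parts of __init__ elided by the diff (self.scale_x, self.func) exist in the full file:

import torch
from torch_geometric.utils import erdos_renyi_graph

model = CorrectionNet(input_dim=1, output_dim=1, hidden_dim=32, n_layers=4)

num_nodes = 10
x = torch.rand(num_nodes, 1)                 # nodal temperatures T_old
edge_index = erdos_renyi_graph(num_nodes, edge_prob=0.3)
edge_attr = torch.rand(edge_index.size(1))   # physical conductivity per edge

delta_T = model(x, edge_index, edge_attr)    # network predicts the correction
T_new = x + delta_T                          # explicit Euler step, applied by the caller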