fix module and model + add curriculum callback
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 from torch_geometric.nn import MessagePassing
+from torch.nn.utils import spectral_norm
 
 
 class DiffusionLayer(MessagePassing):
@@ -13,28 +14,34 @@ class DiffusionLayer(MessagePassing):
         channels: int,
         **kwargs,
     ):
 
         super().__init__(aggr="add", **kwargs)
 
         self.dt = nn.Parameter(torch.tensor(1e-4))
         self.conductivity_net = nn.Sequential(
-            nn.Linear(channels, channels, bias=False),
+            spectral_norm(nn.Linear(channels, channels, bias=False)),
             nn.GELU(),
-            nn.Linear(channels, channels, bias=False),
+            spectral_norm(nn.Linear(channels, channels, bias=False)),
         )
 
         self.phys_encoder = nn.Sequential(
-            nn.Linear(1, 8, bias=False),
+            spectral_norm(nn.Linear(1, 8, bias=True)),
             nn.Tanh(),
-            nn.Linear(8, 1, bias=False),
+            spectral_norm(nn.Linear(8, 1, bias=True)),
             nn.Softplus(),
         )
+
+        self.alpha_param = nn.Parameter(torch.tensor(1e-2))
+
+    @property
+    def alpha(self):
+        return torch.clamp(self.alpha_param, min=1e-5, max=1.0)
+
     def forward(self, x, edge_index, edge_weight, conductivity):
         edge_weight = edge_weight.unsqueeze(-1)
         conductance = self.phys_encoder(edge_weight)
         net_flux = self.propagate(edge_index, x=x, conductance=conductance)
-        return x + ((net_flux) * self.dt)
-        # return (1-self.alpha) * x + self.alpha * net_flux
-        # return net_flux + x
+        return x + self.alpha * net_flux
+
     def message(self, x_i, x_j, conductance):
         delta = x_j - x_i
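The substance of this hunk: every plain nn.Linear in DiffusionLayer is now wrapped in spectral_norm, which divides the weight by a power-iteration estimate of its largest singular value on each forward pass, making each sub-layer approximately 1-Lipschitz so the learned diffusion update cannot amplify features. A minimal standalone check of that effect (the layer sizes here are illustrative, not from the commit):

import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

layer = spectral_norm(nn.Linear(8, 8, bias=False))
x = torch.randn(4, 8)
for _ in range(10):
    _ = layer(x)  # each forward pass runs one power-iteration step
print(torch.linalg.svdvals(layer.weight).max())  # ~1.0

Note that spectral_norm constrains only the weight matrix, not the bias, so the simultaneous bias=False to bias=True switch in phys_encoder leaves the Lipschitz bound intact.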
@@ -44,15 +51,21 @@ class DiffusionLayer(MessagePassing):
 
 
 class DiffusionNet(nn.Module):
-    def __init__(self, input_dim=1, output_dim=1, hidden_dim=8, n_layers=4):
+    def __init__(
+        self,
+        input_dim=1,
+        output_dim=1,
+        hidden_dim=8,
+        n_layers=4,
+        shared_weights=False,
+    ):
         super().__init__()
 
         # Encoder: Projects input temperature to hidden feature space
         self.enc = nn.Sequential(
-            nn.Linear(input_dim, hidden_dim, bias=True),
-            nn.GELU(),
-            nn.Linear(hidden_dim, hidden_dim, bias=True),
+            spectral_norm(nn.Linear(input_dim, hidden_dim, bias=True)),
+            nn.GELU(),
+            spectral_norm(nn.Linear(hidden_dim, hidden_dim, bias=True)),
         )
 
         self.scale_x = nn.Parameter(torch.zeros(hidden_dim))
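scale_x acts as a per-channel log-gain on the encoder output: forward (see the later hunk) applies it as torch.exp(self.scale_x), so initializing it to zeros gives a gain of exactly one at the start of training while keeping the learned gain strictly positive. A one-line sanity check:

import torch
print(torch.exp(torch.zeros(8)))  # tensor([1., 1., 1., 1., 1., 1., 1., 1.])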
@@ -60,27 +73,40 @@ class DiffusionNet(nn.Module):
         # Scale parameters for conditioning
         self.scale_edge_attr = nn.Parameter(torch.zeros(1))
 
-        # Stack of Diffusion Layers
-        self.layers = torch.nn.ModuleList(
-            [DiffusionLayer(hidden_dim) for _ in range(n_layers)]
-        )
+        # If shared_weights is True, use the same DiffusionLayer multiple times
+        if shared_weights:
+            diffusion_layer = DiffusionLayer(hidden_dim)
+            self.layers = torch.nn.ModuleList(
+                [diffusion_layer for _ in range(n_layers)]
+            )
+        # If shared_weights is False, use separate DiffusionLayers
+        else:
+            # Stack of Diffusion Layers
+            self.layers = torch.nn.ModuleList(
+                [DiffusionLayer(hidden_dim) for _ in range(n_layers)]
+            )
 
         # Decoder: Projects hidden features back to Temperature space
         self.dec = nn.Sequential(
-            nn.Linear(hidden_dim, hidden_dim, bias=True),
+            spectral_norm(nn.Linear(hidden_dim, hidden_dim, bias=True)),
             nn.GELU(),
-            nn.Linear(hidden_dim, output_dim, bias=True),
+            spectral_norm(nn.Linear(hidden_dim, output_dim, bias=True)),
             nn.Softplus(),  # Ensure positive temperature output
         )
 
         self.func = torch.nn.GELU()
 
+        self.dt_param = nn.Parameter(torch.tensor(1e-2))
+
+    @property
+    def dt(self):
+        return torch.clamp(self.dt_param, min=1e-5, max=0.5)
 
     def forward(self, x, edge_index, edge_attr, conductivity):
         # 1. Global Residual Connection setup
         # We save the input to add it back at the very end.
         # The network learns the correction (Delta T), not the absolute T.
         x_input = x
 
         # 2. Encode
         h = self.enc(x) * torch.exp(self.scale_x)
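With shared_weights=True the ModuleList holds the same DiffusionLayer object n_layers times, so every diffusion step applies identical weights, like unrolling a single recurrent update; PyTorch deduplicates shared parameters, which a parameter count makes visible. A sketch, assuming the committed file is importable as model.py (the filename is a guess; it is not shown in this view):

import torch
from model import DiffusionNet  # hypothetical import path

tied = DiffusionNet(hidden_dim=8, n_layers=4, shared_weights=True)
untied = DiffusionNet(hidden_dim=8, n_layers=4, shared_weights=False)
print(sum(p.numel() for p in tied.parameters()))    # one layer's worth of diffusion weights
print(sum(p.numel() for p in untied.parameters()))  # roughly n_layers independent copies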
@@ -98,5 +124,4 @@ class DiffusionNet(nn.Module):
 
         # 6. Final Update (Explicit Euler Step)
         # T_new = T_old + Correction
-        # return x_input + delta_x
-        return delta_x
+        return delta_x + x_input * self.dt
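For reference, the "Explicit Euler Step" the comments describe is T_new = T_old + dt * net_flux; the committed line instead scales the skip connection x_input by the clamped dt, which is a different update rule than the T_new = T_old + Correction comment states. The textbook form, self-contained on a toy 3-node path graph:

import torch

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])  # undirected path graph 0-1-2
u = torch.tensor([1.0, 0.0, 0.0])          # initial temperature
dt = 0.1
for _ in range(5):
    src, dst = edge_index
    net_flux = torch.zeros_like(u).index_add(0, dst, u[src] - u[dst])
    u = u + dt * net_flux  # T_new = T_old + dt * Correction
print(u, u.sum())  # heat spreads along the path; the total stays at 1.0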