add model and fix module and datamodule

This commit is contained in:
FilippoOlivo
2025-12-01 10:06:07 +01:00
parent 88bc5c05e4
commit c36c59d08d
4 changed files with 359 additions and 212 deletions

View File

@@ -0,0 +1,100 @@
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
class DiffusionLayer(MessagePassing):
"""
Modella: T_new = T_old + dt * Divergenza(Flusso)
"""
def __init__(
self,
channels: int,
**kwargs,
):
super().__init__(aggr='add', **kwargs)
self.dt = nn.Parameter(torch.tensor(1e-4))
self.conductivity_net = nn.Sequential(
nn.Linear(channels, channels, bias=False),
nn.GELU(),
nn.Linear(channels, channels, bias=False),
)
self.phys_encoder = nn.Sequential(
nn.Linear(1, 8, bias=False),
nn.Tanh(),
nn.Linear(8, 1, bias=False),
nn.Softplus()
)
def forward(self, x, edge_index, edge_weight):
edge_weight = edge_weight.unsqueeze(-1)
conductance = self.phys_encoder(edge_weight)
net_flux = self.propagate(edge_index, x=x, conductance=conductance)
return x + (net_flux * self.dt)
def message(self, x_i, x_j, conductance):
delta = x_j - x_i
flux = delta * conductance
flux = flux + self.conductivity_net(flux)
return flux
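# Note: message() returns conductance * (x_j - x_i) plus a learned nonlinear
# correction from conductivity_net; aggr='add' then sums these fluxes per
# node, acting as a learned graph-Laplacian divergence, and forward() takes
# one explicit Euler step, x + dt * net_flux, matching the docstring above.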
class DiffusionNet(nn.Module):
def __init__(self, input_dim=1, output_dim=1, hidden_dim=8, n_layers=4):
super().__init__()
# Encoder: Projects input temperature to hidden feature space
self.enc = nn.Sequential(
nn.Linear(input_dim, hidden_dim, bias=True),
nn.GELU(),
nn.Linear(hidden_dim, hidden_dim, bias=True),
nn.GELU(),
)
        # Scale parameters for conditioning
        self.scale_x = nn.Parameter(torch.zeros(hidden_dim))
        self.scale_edge_attr = nn.Parameter(torch.zeros(1))
# Stack of Diffusion Layers
self.layers = torch.nn.ModuleList(
[DiffusionLayer(hidden_dim) for _ in range(n_layers)]
)
# Decoder: Projects hidden features back to Temperature space
self.dec = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim, bias=True),
nn.GELU(),
nn.Linear(hidden_dim, output_dim, bias=True),
nn.Softplus(), # Ensure positive temperature output
)
self.func = torch.nn.GELU()
def forward(self, x, edge_index, edge_attr):
        # 1. Keep the input for an optional global residual connection
        # (currently disabled in the return below): the network learns
        # the correction (Delta T), not the absolute T.
x_input = x
# 2. Encode
h = self.enc(x) * torch.exp(self.scale_x)
        # 3. Scale edge attributes (learnable gating of physical conductivity)
w = edge_attr * torch.exp(self.scale_edge_attr)
# 4. Message Passing (Diffusion Steps)
for layer in self.layers:
# h is updated internally via residual connection in DiffusionLayer
h = layer(h, edge_index, w)
h = self.func(h)
# 5. Decode
delta_x = self.dec(h)
        # 6. Final update: the global residual is currently disabled, so the
        # network returns only the learned correction (Delta T).
        # return x_input + delta_x
        return delta_x
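For orientation, a minimal smoke test of the new model (an illustrative sketch, not part of the commit; the toy graph, tensor names, and shapes are assumptions):

import torch

model = DiffusionNet(input_dim=1, output_dim=1, hidden_dim=8, n_layers=4)
x = torch.rand(3, 1)                        # one temperature per node
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])   # 3-node path, both directions
edge_attr = torch.ones(edge_index.size(1))  # one conductivity per edge
delta_t = model(x, edge_index, edge_attr)   # predicted correction, shape (3, 1)
t_new = x + delta_t                         # explicit Euler update by the caller

Note that because forward() returns only delta_x, the residual addition has to happen outside the model.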

View File

@@ -2,68 +2,209 @@ import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch.nn.utils import spectral_norm
from torch_geometric.nn.conv import GCNConv
from torch_geometric.nn.conv import GCNConv, SAGEConv, GatedGraphConv, GraphConv
class GCNConvLayer(MessagePassing):
def __init__(self, in_channels, out_channels):
super().__init__(aggr="add")
self.lin_l = nn.Linear(in_channels, out_channels, bias=True)
# self.lin_r = spectral_norm(nn.Linear(in_channels, out_channels, bias=False))
# class GCNConvLayer(MessagePassing):
# def __init__(
# self,
# in_channels,
# out_channels,
# aggr: str = 'mean',
# bias: bool = True,
# **kwargs,
# ):
# super().__init__(aggr=aggr, **kwargs)
def forward(self, x, edge_index, edge_attr, deg):
out = self.propagate(edge_index, x=x, edge_attr=edge_attr, deg=deg)
out = self.lin_l(out)
return out
# self.in_channels = in_channels
# self.out_channels = out_channels
def message(self, x_j, edge_attr):
return x_j * edge_attr.view(-1, 1)
# if isinstance(in_channels, int):
# in_channels = (in_channels, in_channels)
def aggregate(self, inputs, index, deg):
"""
TODO: add docstring.
"""
out = super().aggregate(inputs, index)
deg = deg + 1e-7
return out / deg.view(-1, 1)
# self.lin_rel = nn.Linear(in_channels[0], out_channels, bias=bias)
# self.lin_root = nn.Linear(in_channels[1], out_channels, bias=False)
# self.reset_parameters()
# def reset_parameters(self):
# super().reset_parameters()
# self.lin_rel.reset_parameters()
# self.lin_root.reset_parameters()
# def forward(self, x, edge_index,
# edge_weight = None, size = None):
# edge_weight = self.normalize(edge_weight, edge_index, x.size(0), dtype=x.dtype)
# out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
# size=size)
# out = self.lin_rel(out)
# out = out + self.lin_root(x)
# return out
# def message(self, x_j, edge_weight):
# return x_j * edge_weight.view(-1, 1)
# @staticmethod
# def normalize(edge_weights, edge_index, num_nodes, dtype=None):
# """Symmetrically normalize edge weights."""
# if dtype is None:
# dtype = edge_weights.dtype
# device = edge_index.device
# row, col = edge_index
# deg = torch.zeros(num_nodes, device=device, dtype=dtype)
# deg = deg.scatter_add(0, row, edge_weights)
# deg_inv_sqrt = deg.pow(-0.5)
# deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0
# return deg_inv_sqrt[row] * edge_weights * deg_inv_sqrt[col]
# class CorrectionNet(nn.Module):
# def __init__(self, input_dim=1, output_dim=1, hidden_dim=8, n_layers=8):
# super().__init__()
# self.enc = nn.Linear(input_dim, hidden_dim, bias=True)
# self.scale_x = nn.Parameter(torch.zeros(hidden_dim))
# self.scale_edge_attr = nn.Parameter(torch.zeros(1))
# self.layers = torch.nn.ModuleList(
# [GCNConv(hidden_dim, hidden_dim, aggr="mean") for _ in range(n_layers)]
# )
# self.dec = nn.Linear(hidden_dim, output_dim, bias=True)
# self.func = torch.nn.GELU()
# def forward(self, x, edge_index, edge_attr,):
# h = self.enc(x) # * torch.exp(self.scale_x)
# edge_attr = edge_attr # * torch.exp(self.scale_edge_attr)
# h = self.func(h)
# for l in self.layers:
# h = l(h, edge_index, edge_attr)
# h = self.func(h)
# out = self.dec(h)
# return out
# class MLPNet(nn.Module):
# def __init__(self, input_dim=1, output_dim=1, hidden_dim=8, n_layers=1):
# super().__init__()
# layers = []
# func = torch.nn.ReLU
# self.network = nn.Sequential(
# nn.Linear(input_dim, hidden_dim),
# func(),
# nn.Linear(hidden_dim, hidden_dim),
# func(),
# nn.Linear(hidden_dim, hidden_dim),
# func(),
# nn.Linear(hidden_dim, output_dim),
# )
# def forward(self, x, edge_index=None, edge_attr=None):
# return self.network(x)
class DiffusionLayer(MessagePassing):
"""
Modella: T_new = T_old + dt * Divergenza(Flusso)
"""
def __init__(
self,
channels: int,
**kwargs,
):
super().__init__(aggr='add', **kwargs)
self.dt = nn.Parameter(torch.tensor(1e-4))
self.conductivity_net = nn.Sequential(
nn.Linear(channels, channels, bias=False),
nn.GELU(),
nn.Linear(channels, channels, bias=False),
)
self.phys_encoder = nn.Sequential(
nn.Linear(1, 8, bias=False),
nn.Tanh(),
nn.Linear(8, 1, bias=False),
nn.Softplus()
)
def forward(self, x, edge_index, edge_weight):
edge_weight = edge_weight.unsqueeze(-1)
conductance = self.phys_encoder(edge_weight)
net_flux = self.propagate(edge_index, x=x, conductance=conductance)
return x + (net_flux * self.dt)
def message(self, x_i, x_j, conductance):
delta = x_j - x_i
flux = delta * conductance
flux = flux + self.conductivity_net(flux)
return flux
class CorrectionNet(nn.Module):
def __init__(self, input_dim=1, output_dim=1, hidden_dim=8, n_layers=8):
def __init__(self, input_dim=1, output_dim=1, hidden_dim=32, n_layers=4):
super().__init__()
self.enc = nn.Linear(input_dim, hidden_dim, bias=False)
# self.layers = n_layers
# self.l = GCNConv(hidden_dim, hidden_dim, aggr="mean")
self.layers = torch.nn.ModuleList(
[GCNConv(hidden_dim, hidden_dim, aggr="mean", bias=False) for _ in range(n_layers)]
)
self.dec = nn.Linear(hidden_dim, output_dim)
def forward(self, x, edge_index, edge_attr,):
h = self.enc(x)
# h = self.relu(h)
for l in self.layers:
# print(f"Forward pass layer {_}")
h = l(h, edge_index, edge_attr)
# h = self.relu(h)
out = self.dec(h)
return out
class MLPNet(nn.Module):
def __init__(self, input_dim=1, output_dim=1, hidden_dim=8, n_layers=1):
super().__init__()
layers = []
func = torch.nn.ReLU
self.network = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
func(),
nn.Linear(hidden_dim, hidden_dim),
func(),
nn.Linear(hidden_dim, hidden_dim),
func(),
nn.Linear(hidden_dim, output_dim),
# Encoder: Projects input temperature to hidden feature space
self.enc = nn.Sequential(
nn.Linear(input_dim, hidden_dim, bias=True),
nn.GELU(),
nn.Linear(hidden_dim, hidden_dim, bias=True),
nn.GELU(),
)
def forward(self, x, edge_index=None, edge_attr=None):
return self.network(x)
        # Scale parameters for conditioning
        self.scale_x = nn.Parameter(torch.zeros(hidden_dim))
        self.scale_edge_attr = nn.Parameter(torch.zeros(1))
# Stack of Diffusion Layers
self.layers = torch.nn.ModuleList(
[DiffusionLayer(hidden_dim) for _ in range(n_layers)]
)
# Decoder: Projects hidden features back to Temperature space
self.dec = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim, bias=True),
nn.GELU(),
nn.Linear(hidden_dim, output_dim, bias=True),
nn.Softplus(), # Ensure positive temperature output
)
self.func = torch.nn.GELU()
def forward(self, x, edge_index, edge_attr):
        # 1. Keep the input for an optional global residual connection
        # (currently disabled in the return below): the network learns
        # the correction (Delta T), not the absolute T.
x_input = x
# 2. Encode
h = self.enc(x) * torch.exp(self.scale_x)
        # 3. Scale edge attributes (learnable gating of physical conductivity)
w = edge_attr * torch.exp(self.scale_edge_attr)
# 4. Message Passing (Diffusion Steps)
for layer in self.layers:
# h is updated internally via residual connection in DiffusionLayer
h = layer(h, edge_index, w)
h = self.func(h)
# 5. Decode
delta_x = self.dec(h)
        # 6. Final update: the global residual is currently disabled, so the
        # network returns only the learned correction (Delta T).
        # return x_input + delta_x
return delta_x
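A one-step training sketch for the reworked CorrectionNet (illustrative only; the optimizer choice, toy snapshots, and variable names are assumptions). Since forward() returns the correction rather than x + delta_x, the regression target is the increment between consecutive temperature snapshots:

import torch

model = CorrectionNet(input_dim=1, output_dim=1, hidden_dim=32, n_layers=4)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])            # same toy graph as above
edge_attr = torch.ones(edge_index.size(1))
t_curr, t_next = torch.rand(3, 1), torch.rand(3, 1)  # consecutive snapshots

pred_delta = model(t_curr, edge_index, edge_attr)
loss = torch.nn.functional.mse_loss(pred_delta, t_next - t_curr)
loss.backward()
optim.step()
optim.zero_grad()

One caveat worth flagging: the decoder ends in Softplus, so the predicted correction is strictly positive, while real increments t_next - t_curr can be negative; cooling dynamics cannot be represented without changing that final activation.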