fix model

This commit is contained in:
FilippoOlivo
2025-12-01 14:55:13 +01:00
parent c36c59d08d
commit 54bebf7154
5 changed files with 167 additions and 88 deletions

View File

@@ -7,6 +7,7 @@ from matplotlib.tri import Triangulation
from .model.finite_difference import FiniteDifferenceStep from .model.finite_difference import FiniteDifferenceStep
import os import os
def import_class(class_path: str): def import_class(class_path: str):
module_path, class_name = class_path.rsplit(".", 1) # split last dot module_path, class_name = class_path.rsplit(".", 1) # split last dot
module = importlib.import_module(module_path) # import the module module = importlib.import_module(module_path) # import the module
@@ -14,7 +15,7 @@ def import_class(class_path: str):
return cls return cls
def _plot_mesh(pos_, y_, y_pred_, y_true_ ,batch, i, batch_idx): def _plot_mesh(pos_, y_, y_pred_, y_true_, batch, i, batch_idx):
for j in [0, 10, 20, 30]: for j in [0, 10, 20, 30]:
idx = (batch == j).nonzero(as_tuple=True)[0] idx = (batch == j).nonzero(as_tuple=True)[0]
y = y_[idx].detach().cpu() y = y_[idx].detach().cpu()
@@ -49,6 +50,7 @@ def _plot_mesh(pos_, y_, y_pred_, y_true_ ,batch, i, batch_idx):
plt.savefig(name, dpi=72) plt.savefig(name, dpi=72)
plt.close() plt.close()
def _plot_losses(losses, batch_idx): def _plot_losses(losses, batch_idx):
folder = f"{batch_idx:02d}_images" folder = f"{batch_idx:02d}_images"
plt.figure() plt.figure()
@@ -74,8 +76,8 @@ class GraphSolver(LightningModule):
super().__init__() super().__init__()
self.model = import_class(model_class_path)(**model_init_args) self.model = import_class(model_class_path)(**model_init_args)
# for param in self.model.parameters(): # for param in self.model.parameters():
# print(f"Param: {param.shape}, Grad: {param.grad}") # print(f"Param: {param.shape}, Grad: {param.grad}")
# print(f"Param: {param[0]}") # print(f"Param: {param[0]}")
self.loss = loss if loss is not None else torch.nn.MSELoss() self.loss = loss if loss is not None else torch.nn.MSELoss()
self.unrolling_steps = unrolling_steps self.unrolling_steps = unrolling_steps
@@ -101,29 +103,36 @@ class GraphSolver(LightningModule):
return (0.5 * (c[edge_index[0]] + c[edge_index[1]])).squeeze() return (0.5 * (c[edge_index[0]] + c[edge_index[1]])).squeeze()
def _compute_model_steps( def _compute_model_steps(
self, x, edge_index, edge_attr, boundary_mask, boundary_values self,
): x,
out = self.model(x, edge_index, edge_attr) edge_index,
edge_attr,
boundary_mask,
boundary_values,
conductivity,
):
out = self.model(x, edge_index, edge_attr, conductivity)
out[boundary_mask] = boundary_values.unsqueeze(-1) out[boundary_mask] = boundary_values.unsqueeze(-1)
# print(torch.min(out), torch.max(out))
return out return out
def _preprocess_batch(self, batch: Batch): def _preprocess_batch(self, batch: Batch):
x, y, c, edge_index, edge_attr = ( x, y, c, edge_index, edge_attr, nodal_area = (
batch.x, batch.x,
batch.y, batch.y,
batch.c, batch.c,
batch.edge_index, batch.edge_index,
batch.edge_attr, batch.edge_attr,
batch.nodal_area,
) )
edge_attr = 1 / edge_attr edge_attr = 1 / edge_attr
c_ij = self._compute_c_ij(c, edge_index) conductivity = self._compute_c_ij(c, edge_index)
edge_attr = edge_attr * c_ij edge_attr = edge_attr * conductivity
# edge_attr = edge_attr / torch.max(edge_attr) return x, y, edge_index, edge_attr, conductivity
return x, y, edge_index, edge_attr
def training_step(self, batch: Batch): def training_step(self, batch: Batch):
x, y, edge_index, edge_attr = self._preprocess_batch(batch) x, y, edge_index, edge_attr, conductivity = self._preprocess_batch(
batch
)
# deg = self._compute_deg(edge_index, edge_attr, x.size(0)) # deg = self._compute_deg(edge_index, edge_attr, x.size(0))
losses = [] losses = []
# print(x.shape, y.shape) # print(x.shape, y.shape)
@@ -160,12 +169,13 @@ class GraphSolver(LightningModule):
# deg, # deg,
batch.boundary_mask, batch.boundary_mask,
batch.boundary_values, batch.boundary_values,
conductivity,
) )
x = out x = out
# print(out.shape, y[:, i, :].shape) # print(out.shape, y[:, i, :].shape)
losses.append(self.loss(out.flatten(), y[:, i, :].flatten())) losses.append(self.loss(out.flatten(), y[:, i, :].flatten()))
# print(self.model.scale_edge_attr.item()) # print(self.model.scale_edge_attr.item())
loss = torch.stack(losses).mean() loss = torch.stack(losses).mean()
# for param in self.model.parameters(): # for param in self.model.parameters():
# print(f"Param: {param.shape}, Grad: {param.grad}") # print(f"Param: {param.shape}, Grad: {param.grad}")
@@ -173,26 +183,40 @@ class GraphSolver(LightningModule):
self._log_loss(loss, batch, "train") self._log_loss(loss, batch, "train")
return loss return loss
def validation_step(self, batch: Batch, batch_idx): def validation_step(self, batch: Batch, batch_idx):
x, y, edge_index, edge_attr = self._preprocess_batch(batch) x, y, edge_index, edge_attr, conductivity = self._preprocess_batch(
batch
)
# deg = self._compute_deg(edge_index, edge_attr, x.size(0)) # deg = self._compute_deg(edge_index, edge_attr, x.size(0))
losses = [] losses = []
pos = batch.pos pos = batch.pos
for i in range(self.unrolling_steps): for i in range(self.unrolling_steps):
out = self._compute_model_steps( out = self._compute_model_steps(
# torch.cat([x,pos], dim=-1), # torch.cat([x,pos], dim=-1),
x, x,
edge_index, edge_index,
edge_attr, edge_attr,
# deg, # deg,
batch.boundary_mask, batch.boundary_mask,
batch.boundary_values, batch.boundary_values,
conductivity,
) )
if (batch_idx == 0 and self.current_epoch % 10 == 0 and self.current_epoch > 20): if (
_plot_mesh(batch.pos, x, out, y[:, i, :], batch.batch, i, self.current_epoch) batch_idx == 0
and self.current_epoch % 10 == 0
and self.current_epoch > 0
):
_plot_mesh(
batch.pos,
x,
out,
y[:, i, :],
batch.batch,
i,
self.current_epoch,
)
x = out x = out
losses.append(self.loss(out , y[:, i, :])) losses.append(self.loss(out, y[:, i, :]))
loss = torch.stack(losses).mean() loss = torch.stack(losses).mean()
self._log_loss(loss, batch, "val") self._log_loss(loss, batch, "val")
@@ -202,5 +226,5 @@ class GraphSolver(LightningModule):
pass pass
def configure_optimizers(self): def configure_optimizers(self):
optimizer = torch.optim.AdamW(self.parameters(), lr=5e-3) optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
return optimizer return optimizer

View File

@@ -6,7 +6,39 @@ from torch_geometric.data import Data
from torch_geometric.loader import DataLoader from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_undirected from torch_geometric.utils import to_undirected
from .mesh_data import MeshData from .mesh_data import MeshData
# from torch.utils.data import Dataset # from torch.utils.data import Dataset
from torch_geometric.utils import scatter
def compute_nodal_area(edge_index, edge_attr, num_nodes):
"""
1. Calculates Area ~ (Min Edge Length)^2
2. Scales by Mean so average cell has size 1.0
"""
row, col = edge_index
dist = edge_attr.squeeze()
# 1. Get 'h' (Closest neighbor distance)
# Using 'min' filters out diagonal connections in the quad mesh
h = scatter(dist, col, dim=0, dim_size=num_nodes, reduce="min")
# 2. Estimate Raw Area
raw_area = h.pow(2)
# 3. Mean Scaling (The Best Normalization)
# This keeps values near 1.0, preserving stability AND physics ratios.
# We detach to ensure no gradients flow here (it's static data).
mean_val = raw_area.mean().detach()
# Result:
# Small cells -> approx 0.1
# Large cells -> approx 5.0
# Average -> 1.0
# nodal_area = (raw_area / mean_val).unsqueeze(-1) + 1e-6
nodal_area = raw_area
return nodal_area.unsqueeze(-1)
class GraphDataModule(LightningDataModule): class GraphDataModule(LightningDataModule):
def __init__( def __init__(
@@ -26,7 +58,11 @@ class GraphDataModule(LightningDataModule):
self.hf_repo = hf_repo self.hf_repo = hf_repo
self.split_name = split_name self.split_name = split_name
self.dataset_dict = {} self.dataset_dict = {}
self.train_dataset, self.val_dataset, self.test_dataset = None, None, None self.train_dataset, self.val_dataset, self.test_dataset = (
None,
None,
None,
)
self.unrolling_steps = start_unrolling_steps self.unrolling_steps = start_unrolling_steps
self.geometry_dict = {} self.geometry_dict = {}
self.train_size = train_size self.train_size = train_size
@@ -85,7 +121,9 @@ class GraphDataModule(LightningDataModule):
conductivity = torch.tensor( conductivity = torch.tensor(
geometry["conductivity"], dtype=torch.float32 geometry["conductivity"], dtype=torch.float32
) )
temperatures = torch.tensor(snapshot["temperatures"], dtype=torch.float32)[:40] temperatures = torch.tensor(
snapshot["temperatures"], dtype=torch.float32
)[:40]
times = torch.tensor(snapshot["times"], dtype=torch.float32) times = torch.tensor(snapshot["times"], dtype=torch.float32)
pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2] pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2]
@@ -100,16 +138,19 @@ class GraphDataModule(LightningDataModule):
) )
if self.build_radial_graph: if self.build_radial_graph:
from pina.graph import RadiusGraph # from pina.graph import RadiusGraph
if self.radius is None: # if self.radius is None:
raise ValueError("Radius must be specified for radial graph.") # raise ValueError("Radius must be specified for radial graph.")
edge_index = RadiusGraph.compute_radius_graph( # edge_index = RadiusGraph.compute_radius_graph(
pos, radius=self.radius # pos, radius=self.radius
# )
# from torch_geometric.utils import remove_self_loops
# edge_index, _ = remove_self_loops(edge_index)
raise NotImplementedError(
"Radial graph building not implemented yet."
) )
from torch_geometric.utils import remove_self_loops
edge_index, _ = remove_self_loops(edge_index)
else: else:
edge_index = torch.tensor( edge_index = torch.tensor(
geometry["edge_index"], dtype=torch.int64 geometry["edge_index"], dtype=torch.int64
@@ -117,31 +158,37 @@ class GraphDataModule(LightningDataModule):
edge_index = to_undirected(edge_index, num_nodes=pos.size(0)) edge_index = to_undirected(edge_index, num_nodes=pos.size(0))
boundary_mask, boundary_values = self._compute_boundary_mask( boundary_mask, boundary_values = self._compute_boundary_mask(
bottom_ids, right_ids, top_ids, left_ids, temperatures[0,:] bottom_ids, right_ids, top_ids, left_ids, temperatures[0, :]
) )
edge_attr = torch.norm(pos[edge_index[0]] - pos[edge_index[1]], dim=1)
nodal_area = compute_nodal_area(edge_index, edge_attr, pos.size(0))
if self.remove_boundary_edges: if self.remove_boundary_edges:
boundary_idx = torch.unique(boundary_mask) boundary_idx = torch.unique(boundary_mask)
edge_index_mask = ~torch.isin(edge_index[1], boundary_idx) edge_index_mask = ~torch.isin(edge_index[1], boundary_idx)
edge_index = edge_index[:, edge_index_mask] edge_index = edge_index[:, edge_index_mask]
edge_attr = edge_attr[edge_index_mask]
edge_attr = torch.norm(pos[edge_index[0]] - pos[edge_index[1]], dim=1)
n_data = temperatures.size(0) - self.unrolling_steps n_data = temperatures.size(0) - self.unrolling_steps
data = [] data = []
for i in range(n_data): for i in range(n_data):
x = temperatures[i, :].unsqueeze(-1) x = temperatures[i, :].unsqueeze(-1)
y = temperatures[i + 1 : i + 1 + self.unrolling_steps, :].unsqueeze(-1).permute(1,0,2) y = (
data.append(MeshData( temperatures[i + 1 : i + 1 + self.unrolling_steps, :]
x=x, .unsqueeze(-1)
y=y, .permute(1, 0, 2)
c=conductivity.unsqueeze(-1), )
edge_index=edge_index, data.append(
pos=pos, MeshData(
edge_attr=edge_attr, x=x,
boundary_mask=boundary_mask, y=y,
boundary_values=boundary_values, c=conductivity.unsqueeze(-1),
)) edge_index=edge_index,
pos=pos,
edge_attr=edge_attr,
boundary_mask=boundary_mask,
boundary_values=boundary_values,
nodal_area=nodal_area,
)
)
return data return data
def setup(self, stage: str = None): def setup(self, stage: str = None):
@@ -207,7 +254,9 @@ class GraphDataModule(LightningDataModule):
) )
def test_dataloader(self): def test_dataloader(self):
ds = self.create_autoregressive_datasets(dataset="test", no_unrolling=True) ds = self.create_autoregressive_datasets(
dataset="test", no_unrolling=True
)
return DataLoader( return DataLoader(
ds, ds,
batch_size=self.batch_size, batch_size=self.batch_size,

View File

@@ -7,6 +7,7 @@ from matplotlib.tri import Triangulation
from .model.finite_difference import FiniteDifferenceStep from .model.finite_difference import FiniteDifferenceStep
import os import os
def import_class(class_path: str): def import_class(class_path: str):
module_path, class_name = class_path.rsplit(".", 1) # split last dot module_path, class_name = class_path.rsplit(".", 1) # split last dot
module = importlib.import_module(module_path) # import the module module = importlib.import_module(module_path) # import the module
@@ -43,6 +44,7 @@ def _plot_mesh(pos, y, y_pred, batch, i, batch_idx):
plt.savefig(name, dpi=72) plt.savefig(name, dpi=72)
plt.close() plt.close()
def _plot_losses(losses, batch_idx): def _plot_losses(losses, batch_idx):
folder = f"{batch_idx:02d}_images" folder = f"{batch_idx:02d}_images"
plt.figure() plt.figure()

View File

@@ -2,37 +2,39 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch_geometric.nn import MessagePassing from torch_geometric.nn import MessagePassing
class DiffusionLayer(MessagePassing): class DiffusionLayer(MessagePassing):
""" """
Modella: T_new = T_old + dt * Divergenza(Flusso) Modella: T_new = T_old + dt * Divergenza(Flusso)
""" """
def __init__( def __init__(
self, self,
channels: int, channels: int,
**kwargs, **kwargs,
): ):
super().__init__(aggr='add', **kwargs) super().__init__(aggr="add", **kwargs)
self.dt = nn.Parameter(torch.tensor(1e-4)) self.dt = nn.Parameter(torch.tensor(1e-4))
self.conductivity_net = nn.Sequential( self.conductivity_net = nn.Sequential(
nn.Linear(channels, channels, bias=False), nn.Linear(channels, channels, bias=False),
nn.GELU(), nn.GELU(),
nn.Linear(channels, channels, bias=False), nn.Linear(channels, channels, bias=False),
) )
self.phys_encoder = nn.Sequential( self.phys_encoder = nn.Sequential(
nn.Linear(1, 8, bias=False), nn.Linear(1, 8, bias=False),
nn.Tanh(), nn.Tanh(),
nn.Linear(8, 1, bias=False), nn.Linear(8, 1, bias=False),
nn.Softplus() nn.Softplus(),
) )
def forward(self, x, edge_index, edge_weight): def forward(self, x, edge_index, edge_weight, conductivity):
edge_weight = edge_weight.unsqueeze(-1) edge_weight = edge_weight.unsqueeze(-1)
conductance = self.phys_encoder(edge_weight) conductance = self.phys_encoder(edge_weight)
net_flux = self.propagate(edge_index, x=x, conductance=conductance) net_flux = self.propagate(edge_index, x=x, conductance=conductance)
return x + (net_flux * self.dt) return x + ((net_flux) * self.dt)
def message(self, x_i, x_j, conductance): def message(self, x_i, x_j, conductance):
delta = x_j - x_i delta = x_j - x_i
@@ -44,7 +46,7 @@ class DiffusionLayer(MessagePassing):
class DiffusionNet(nn.Module): class DiffusionNet(nn.Module):
def __init__(self, input_dim=1, output_dim=1, hidden_dim=8, n_layers=4): def __init__(self, input_dim=1, output_dim=1, hidden_dim=8, n_layers=4):
super().__init__() super().__init__()
# Encoder: Projects input temperature to hidden feature space # Encoder: Projects input temperature to hidden feature space
self.enc = nn.Sequential( self.enc = nn.Sequential(
nn.Linear(input_dim, hidden_dim, bias=True), nn.Linear(input_dim, hidden_dim, bias=True),
@@ -57,12 +59,12 @@ class DiffusionNet(nn.Module):
# Scale parameters for conditioning # Scale parameters for conditioning
self.scale_edge_attr = nn.Parameter(torch.zeros(1)) self.scale_edge_attr = nn.Parameter(torch.zeros(1))
# Stack of Diffusion Layers # Stack of Diffusion Layers
self.layers = torch.nn.ModuleList( self.layers = torch.nn.ModuleList(
[DiffusionLayer(hidden_dim) for _ in range(n_layers)] [DiffusionLayer(hidden_dim) for _ in range(n_layers)]
) )
# Decoder: Projects hidden features back to Temperature space # Decoder: Projects hidden features back to Temperature space
self.dec = nn.Sequential( self.dec = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim, bias=True), nn.Linear(hidden_dim, hidden_dim, bias=True),
@@ -73,28 +75,28 @@ class DiffusionNet(nn.Module):
self.func = torch.nn.GELU() self.func = torch.nn.GELU()
def forward(self, x, edge_index, edge_attr): def forward(self, x, edge_index, edge_attr, conductivity):
# 1. Global Residual Connection setup # 1. Global Residual Connection setup
# We save the input to add it back at the very end. # We save the input to add it back at the very end.
# The network learns the correction (Delta T), not the absolute T. # The network learns the correction (Delta T), not the absolute T.
x_input = x x_input = x
# 2. Encode # 2. Encode
h = self.enc(x) * torch.exp(self.scale_x) h = self.enc(x) * torch.exp(self.scale_x)
# Scale edge attributes (learnable gating of physical conductivity) # Scale edge attributes (learnable gating of physical conductivity)
w = edge_attr * torch.exp(self.scale_edge_attr) w = edge_attr * torch.exp(self.scale_edge_attr)
# 4. Message Passing (Diffusion Steps) # 4. Message Passing (Diffusion Steps)
for layer in self.layers: for layer in self.layers:
# h is updated internally via residual connection in DiffusionLayer # h is updated internally via residual connection in DiffusionLayer
h = layer(h, edge_index, w) h = layer(h, edge_index, w, conductivity)
h = self.func(h) h = self.func(h)
# 5. Decode # 5. Decode
delta_x = self.dec(h) delta_x = self.dec(h)
# 6. Final Update (Explicit Euler Step) # 6. Final Update (Explicit Euler Step)
# T_new = T_old + Correction # T_new = T_old + Correction
# return x_input + delta_x # return x_input + delta_x
return delta_ddx return delta_x

View File

@@ -44,7 +44,7 @@ from torch_geometric.nn.conv import GCNConv, SAGEConv, GatedGraphConv, GraphConv
# def message(self, x_j, edge_weight): # def message(self, x_j, edge_weight):
# return x_j * edge_weight.view(-1, 1) # return x_j * edge_weight.view(-1, 1)
# @staticmethod # @staticmethod
# def normalize(edge_weights, edge_index, num_nodes, dtype=None): # def normalize(edge_weights, edge_index, num_nodes, dtype=None):
# """Symmetrically normalize edge weights.""" # """Symmetrically normalize edge weights."""
@@ -58,7 +58,7 @@ from torch_geometric.nn.conv import GCNConv, SAGEConv, GatedGraphConv, GraphConv
# deg_inv_sqrt = deg.pow(-0.5) # deg_inv_sqrt = deg.pow(-0.5)
# deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0 # deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0
# return deg_inv_sqrt[row] * edge_weights * deg_inv_sqrt[col] # return deg_inv_sqrt[row] * edge_weights * deg_inv_sqrt[col]
# class CorrectionNet(nn.Module): # class CorrectionNet(nn.Module):
# def __init__(self, input_dim=1, output_dim=1, hidden_dim=8, n_layers=8): # def __init__(self, input_dim=1, output_dim=1, hidden_dim=8, n_layers=8):
@@ -89,7 +89,7 @@ from torch_geometric.nn.conv import GCNConv, SAGEConv, GatedGraphConv, GraphConv
# super().__init__() # super().__init__()
# layers = [] # layers = []
# func = torch.nn.ReLU # func = torch.nn.ReLU
# self.network = nn.Sequential( # self.network = nn.Sequential(
# nn.Linear(input_dim, hidden_dim), # nn.Linear(input_dim, hidden_dim),
# func(), # func(),
@@ -112,30 +112,32 @@ from torch_geometric.nn.conv import GCNConv, SAGEConv, GatedGraphConv, GraphConv
# import torch.nn as nn # import torch.nn as nn
# from torch_geometric.nn import MessagePassing # from torch_geometric.nn import MessagePassing
class DiffusionLayer(MessagePassing): class DiffusionLayer(MessagePassing):
""" """
Modella: T_new = T_old + dt * Divergenza(Flusso) Modella: T_new = T_old + dt * Divergenza(Flusso)
""" """
def __init__( def __init__(
self, self,
channels: int, channels: int,
**kwargs, **kwargs,
): ):
super().__init__(aggr='add', **kwargs) super().__init__(aggr="add", **kwargs)
self.dt = nn.Parameter(torch.tensor(1e-4)) self.dt = nn.Parameter(torch.tensor(1e-4))
self.conductivity_net = nn.Sequential( self.conductivity_net = nn.Sequential(
nn.Linear(channels, channels, bias=False), nn.Linear(channels, channels, bias=False),
nn.GELU(), nn.GELU(),
nn.Linear(channels, channels, bias=False), nn.Linear(channels, channels, bias=False),
) )
self.phys_encoder = nn.Sequential( self.phys_encoder = nn.Sequential(
nn.Linear(1, 8, bias=False), nn.Linear(1, 8, bias=False),
nn.Tanh(), nn.Tanh(),
nn.Linear(8, 1, bias=False), nn.Linear(8, 1, bias=False),
nn.Softplus() nn.Softplus(),
) )
def forward(self, x, edge_index, edge_weight): def forward(self, x, edge_index, edge_weight):
@@ -154,7 +156,7 @@ class DiffusionLayer(MessagePassing):
class CorrectionNet(nn.Module): class CorrectionNet(nn.Module):
def __init__(self, input_dim=1, output_dim=1, hidden_dim=32, n_layers=4): def __init__(self, input_dim=1, output_dim=1, hidden_dim=32, n_layers=4):
super().__init__() super().__init__()
# Encoder: Projects input temperature to hidden feature space # Encoder: Projects input temperature to hidden feature space
self.enc = nn.Sequential( self.enc = nn.Sequential(
nn.Linear(input_dim, hidden_dim, bias=True), nn.Linear(input_dim, hidden_dim, bias=True),
@@ -167,12 +169,12 @@ class CorrectionNet(nn.Module):
# Scale parameters for conditioning # Scale parameters for conditioning
self.scale_edge_attr = nn.Parameter(torch.zeros(1)) self.scale_edge_attr = nn.Parameter(torch.zeros(1))
# Stack of Diffusion Layers # Stack of Diffusion Layers
self.layers = torch.nn.ModuleList( self.layers = torch.nn.ModuleList(
[DiffusionLayer(hidden_dim) for _ in range(n_layers)] [DiffusionLayer(hidden_dim) for _ in range(n_layers)]
) )
# Decoder: Projects hidden features back to Temperature space # Decoder: Projects hidden features back to Temperature space
self.dec = nn.Sequential( self.dec = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim, bias=True), nn.Linear(hidden_dim, hidden_dim, bias=True),
@@ -185,26 +187,26 @@ class CorrectionNet(nn.Module):
def forward(self, x, edge_index, edge_attr): def forward(self, x, edge_index, edge_attr):
# 1. Global Residual Connection setup # 1. Global Residual Connection setup
# We save the input to add it back at the very end. # We save the input to add it back at the very end.
# The network learns the correction (Delta T), not the absolute T. # The network learns the correction (Delta T), not the absolute T.
x_input = x x_input = x
# 2. Encode # 2. Encode
h = self.enc(x) * torch.exp(self.scale_x) h = self.enc(x) * torch.exp(self.scale_x)
# Scale edge attributes (learnable gating of physical conductivity) # Scale edge attributes (learnable gating of physical conductivity)
w = edge_attr * torch.exp(self.scale_edge_attr) w = edge_attr * torch.exp(self.scale_edge_attr)
# 4. Message Passing (Diffusion Steps) # 4. Message Passing (Diffusion Steps)
for layer in self.layers: for layer in self.layers:
# h is updated internally via residual connection in DiffusionLayer # h is updated internally via residual connection in DiffusionLayer
h = layer(h, edge_index, w) h = layer(h, edge_index, w)
h = self.func(h) h = self.func(h)
# 5. Decode # 5. Decode
delta_x = self.dec(h) delta_x = self.dec(h)
# 6. Final Update (Explicit Euler Step) # 6. Final Update (Explicit Euler Step)
# T_new = T_old + Correction # T_new = T_old + Correction
# return x_input + delta_x # return x_input + delta_x
return delta_x return delta_x