add model and fix module and datamodule

This commit is contained in:
FilippoOlivo
2025-12-01 10:06:07 +01:00
parent 88bc5c05e4
commit c36c59d08d
4 changed files with 359 additions and 212 deletions

View File

@@ -14,34 +14,38 @@ def import_class(class_path: str):
return cls return cls
def _plot_mesh(pos, y, y_pred, y_true ,batch, i, batch_idx): def _plot_mesh(pos_, y_, y_pred_, y_true_ ,batch, i, batch_idx):
for j in [0, 10, 20, 30]:
idx = batch == 0 idx = (batch == j).nonzero(as_tuple=True)[0]
y = y[idx].detach().cpu() y = y_[idx].detach().cpu()
y_pred = y_pred[idx].detach().cpu() y_pred = y_pred_[idx].detach().cpu()
pos = pos[idx].detach().cpu() pos = pos_[idx].detach().cpu()
y_true = y_true[idx].detach().cpu() y_true = y_true_[idx].detach().cpu()
# print(torch.max(y_true), torch.min(y_true)) y_true = torch.clamp(y_true, min=0)
folder = f"{batch_idx:02d}_images" folder = f"{j:02d}_images"
if os.path.exists(folder) is False: if os.path.exists(folder) is False:
os.makedirs(folder) os.makedirs(folder)
pos = pos.detach().cpu() pos = pos.detach().cpu()
tria = Triangulation(pos[:, 0], pos[:, 1]) tria = Triangulation(pos[:, 0], pos[:, 1])
plt.figure(figsize=(18, 5)) plt.figure(figsize=(24, 5))
plt.subplot(1, 3, 1) plt.subplot(1, 4, 1)
plt.tricontourf(tria, y.squeeze().numpy(), levels=14) plt.tricontourf(tria, y.squeeze().numpy(), levels=100)
plt.colorbar() plt.colorbar()
plt.title("Step t-1") plt.title("Step t-1")
plt.subplot(1, 3, 2) plt.subplot(1, 4, 2)
plt.tricontourf(tria, y_pred.squeeze().numpy(), levels=14) plt.tricontourf(tria, y_pred.squeeze().numpy(), levels=100)
plt.colorbar() plt.colorbar()
plt.title("Step t Predicted") plt.title("Step t Predicted")
plt.subplot(1, 3, 3) plt.subplot(1, 4, 3)
plt.tricontourf(tria, y_true.squeeze().numpy(), levels=14) plt.tricontourf(tria, y_true.squeeze().numpy(), levels=100)
plt.colorbar() plt.colorbar()
plt.title("t True") plt.title("t True")
plt.subplot(1, 4, 4)
plt.tricontourf(tria, (y_true - y_pred).squeeze().numpy(), levels=100)
plt.colorbar()
plt.title("Error")
plt.suptitle("GNO", fontsize=16) plt.suptitle("GNO", fontsize=16)
name = f"{folder}/graph_iter_{i:04d}.png" name = f"{folder}/{j:04d}_graph_iter_{i:04d}.png"
plt.savefig(name, dpi=72) plt.savefig(name, dpi=72)
plt.close() plt.close()
@@ -65,33 +69,15 @@ class GraphSolver(LightningModule):
model_class_path: str, model_class_path: str,
model_init_args: dict = {}, model_init_args: dict = {},
loss: torch.nn.Module = None, loss: torch.nn.Module = None,
start_unrolling_steps: int = 1, unrolling_steps: int = 1,
increase_every: int = 20,
increase_rate: float = 2,
max_unrolling_steps: int = 100,
max_inference_iters: int = 1000,
inner_steps: int = 16,
): ):
super().__init__() super().__init__()
self.model = import_class(model_class_path)(**model_init_args) self.model = import_class(model_class_path)(**model_init_args)
# for param in self.model.parameters(): # for param in self.model.parameters():
# print(f"Param: {param.shape}, Grad: {param.grad}") # print(f"Param: {param.shape}, Grad: {param.grad}")
# print(f"Param: {param[0]}") # print(f"Param: {param[0]}")
self.fd_net = FiniteDifferenceStep()
self.loss = loss if loss is not None else torch.nn.MSELoss() self.loss = loss if loss is not None else torch.nn.MSELoss()
self.start_unrolling = start_unrolling_steps self.unrolling_steps = unrolling_steps
self.current_unrolling_steps = self.start_unrolling
self.increase_every = increase_every
self.increase_rate = increase_rate
self.max_unrolling_steps = max_unrolling_steps
self.max_inference_iters = max_inference_iters
self.threshold = 1e-4
self.inner_steps = inner_steps
def _compute_deg(self, edge_index, edge_attr, num_nodes):
deg = torch.zeros(num_nodes, device=edge_index.device)
deg = deg.scatter_add(0, edge_index[1], edge_attr)
return deg + 1e-7
def _compute_loss(self, x, y): def _compute_loss(self, x, y):
return self.loss(x, y) return self.loss(x, y)
@@ -100,7 +86,7 @@ class GraphSolver(LightningModule):
self.log( self.log(
f"{stage}/loss", f"{stage}/loss",
loss, loss,
on_step=True, on_step=False,
on_epoch=True, on_epoch=True,
prog_bar=True, prog_bar=True,
batch_size=int(batch.num_graphs), batch_size=int(batch.num_graphs),
@@ -117,18 +103,11 @@ class GraphSolver(LightningModule):
def _compute_model_steps( def _compute_model_steps(
self, x, edge_index, edge_attr, boundary_mask, boundary_values self, x, edge_index, edge_attr, boundary_mask, boundary_values
): ):
out = self.model(x, edge_index, edge_attr)
out = x + self.model(x, edge_index, edge_attr) out[boundary_mask] = boundary_values.unsqueeze(-1)
# out[boundary_mask] = boundary_values.unsqueeze(-1) # print(torch.min(out), torch.max(out))
plt.figure()
return out return out
def _check_convergence(self, out, x):
residual_norm = torch.norm(out - x)
if residual_norm < self.threshold * torch.norm(x):
return True
return False
def _preprocess_batch(self, batch: Batch): def _preprocess_batch(self, batch: Batch):
x, y, c, edge_index, edge_attr = ( x, y, c, edge_index, edge_attr = (
batch.x, batch.x,
@@ -137,9 +116,10 @@ class GraphSolver(LightningModule):
batch.edge_index, batch.edge_index,
batch.edge_attr, batch.edge_attr,
) )
# edge_attr = 1 / edge_attr edge_attr = 1 / edge_attr
c_ij = self._compute_c_ij(c, edge_index) c_ij = self._compute_c_ij(c, edge_index)
edge_attr = edge_attr * (c_ij) # / 100) edge_attr = edge_attr * c_ij
# edge_attr = edge_attr / torch.max(edge_attr)
return x, y, edge_index, edge_attr return x, y, edge_index, edge_attr
def training_step(self, batch: Batch): def training_step(self, batch: Batch):
@@ -171,10 +151,9 @@ class GraphSolver(LightningModule):
# plt.scatter(pos[boundary_mask,0].cpu(), pos[boundary_mask,1].cpu(), c=boundary_values.cpu(), s=1) # plt.scatter(pos[boundary_mask,0].cpu(), pos[boundary_mask,1].cpu(), c=boundary_values.cpu(), s=1)
# plt.savefig("boundary_nodes.png", dpi=300) # plt.savefig("boundary_nodes.png", dpi=300)
# y = z # y = z
print(y.shape) scale = 50
for i in range(self.current_unrolling_steps * self.inner_steps): for i in range(self.unrolling_steps):
out = self._compute_model_steps( out = self._compute_model_steps(
# torch.cat([x,pos], dim=-1),
x, x,
edge_index, edge_index,
edge_attr, edge_attr,
@@ -185,8 +164,7 @@ class GraphSolver(LightningModule):
x = out x = out
# print(out.shape, y[:, i, :].shape) # print(out.shape, y[:, i, :].shape)
losses.append(self.loss(out.flatten(), y[:, i, :].flatten())) losses.append(self.loss(out.flatten(), y[:, i, :].flatten()))
# print(self.model.scale_edge_attr.item())
print(losses)
loss = torch.stack(losses).mean() loss = torch.stack(losses).mean()
# for param in self.model.parameters(): # for param in self.model.parameters():
@@ -195,51 +173,13 @@ class GraphSolver(LightningModule):
self._log_loss(loss, batch, "train") self._log_loss(loss, batch, "train")
return loss return loss
# def on_train_epoch_start(self):
# print(f"Current unrolling steps: {self.current_unrolling_steps}, dataset unrolling steps: {self.trainer.datamodule.train_dataset.unrolling_steps}")
# return super().on_train_epoch_start()
def on_train_epoch_end(self):
if (
(self.current_epoch + 1) % self.increase_every == 0
and self.current_epoch > 0
):
dm = self.trainer.datamodule
self.current_unrolling_steps = min(
int(self.current_unrolling_steps * self.increase_rate),
self.max_unrolling_steps
)
dm.unrolling_steps = self.current_unrolling_steps
return super().on_train_epoch_end()
def validation_step(self, batch: Batch, _):
# x, y, edge_index, edge_attr = self._preprocess_batch(batch)
# deg = self._compute_deg(edge_index, edge_attr, x.size(0))
# for i in range(self.max_inference_iters * self.inner_steps):
# out = self._compute_model_steps(
# x,
# edge_index,
# edge_attr,
# deg,
# batch.boundary_mask,
# batch.boundary_values,
# )
# converged = self._check_convergence(out, x)
# x = out
# if converged:
# break
# print(y.shape, out.shape)
# loss = self.loss(out, y[:,-1,:])
# self._log_loss(loss, batch, "val")
# self.log("val/iterations", i + 1, on_step=False, on_epoch=True, prog_bar=True, batch_size=int(batch.num_graphs),)
# return loss
def validation_step(self, batch: Batch, batch_idx):
x, y, edge_index, edge_attr = self._preprocess_batch(batch) x, y, edge_index, edge_attr = self._preprocess_batch(batch)
# deg = self._compute_deg(edge_index, edge_attr, x.size(0)) # deg = self._compute_deg(edge_index, edge_attr, x.size(0))
losses = [] losses = []
pos = batch.pos pos = batch.pos
for i in range(self.current_unrolling_steps * self.inner_steps): for i in range(self.unrolling_steps):
out = self._compute_model_steps( out = self._compute_model_steps(
# torch.cat([x,pos], dim=-1), # torch.cat([x,pos], dim=-1),
x, x,
@@ -249,6 +189,7 @@ class GraphSolver(LightningModule):
batch.boundary_mask, batch.boundary_mask,
batch.boundary_values, batch.boundary_values,
) )
if (batch_idx == 0 and self.current_epoch % 10 == 0 and self.current_epoch > 20):
_plot_mesh(batch.pos, x, out, y[:, i, :], batch.batch, i, self.current_epoch) _plot_mesh(batch.pos, x, out, y[:, i, :], batch.batch, i, self.current_epoch)
x = out x = out
losses.append(self.loss(out , y[:, i, :])) losses.append(self.loss(out , y[:, i, :]))
@@ -258,41 +199,8 @@ class GraphSolver(LightningModule):
return loss return loss
def test_step(self, batch: Batch, batch_idx): def test_step(self, batch: Batch, batch_idx):
x, y, edge_index, edge_attr = self._preprocess_batch(batch) pass
deg = self._compute_deg(edge_index, edge_attr, x.size(0))
losses = []
for i in range(self.max_iters):
out = self._compute_model_steps(
x,
edge_index,
edge_attr.unsqueeze(-1),
deg,
batch.boundary_mask,
batch.boundary_values,
)
converged = self._check_convergence(out, x)
# _plot_mesh(batch.pos, y, out, batch.batch, i, batch_idx)
losses.append(self.loss(out, y).item())
if converged:
break
x = out
loss = self.loss(out, y)
# _plot_losses(losses, batch_idx)
self._log_loss(loss, batch, "test")
self.log(
"test/iterations",
i + 1,
on_step=False,
on_epoch=True,
prog_bar=True,
batch_size=int(batch.num_graphs),
)
def configure_optimizers(self): def configure_optimizers(self):
optimizer = torch.optim.AdamW(self.parameters(), lr=1e-2) optimizer = torch.optim.AdamW(self.parameters(), lr=5e-3)
return optimizer return optimizer
def _impose_bc(self, x: torch.Tensor, data: Batch):
x[data.boundary_mask] = data.boundary_values
return x

View File

@@ -85,7 +85,7 @@ class GraphDataModule(LightningDataModule):
conductivity = torch.tensor( conductivity = torch.tensor(
geometry["conductivity"], dtype=torch.float32 geometry["conductivity"], dtype=torch.float32
) )
temperatures = torch.tensor(snapshot["temperatures"], dtype=torch.float32)[:2] temperatures = torch.tensor(snapshot["temperatures"], dtype=torch.float32)[:40]
times = torch.tensor(snapshot["times"], dtype=torch.float32) times = torch.tensor(snapshot["times"], dtype=torch.float32)
pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2] pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2]
@@ -131,9 +131,7 @@ class GraphDataModule(LightningDataModule):
data = [] data = []
for i in range(n_data): for i in range(n_data):
x = temperatures[i, :].unsqueeze(-1) x = temperatures[i, :].unsqueeze(-1)
print(x.shape)
y = temperatures[i + 1 : i + 1 + self.unrolling_steps, :].unsqueeze(-1).permute(1,0,2) y = temperatures[i + 1 : i + 1 + self.unrolling_steps, :].unsqueeze(-1).permute(1,0,2)
# print(y.shape)
data.append(MeshData( data.append(MeshData(
x=x, x=x,
y=y, y=y,
@@ -187,9 +185,9 @@ class GraphDataModule(LightningDataModule):
def train_dataloader(self): def train_dataloader(self):
# ds = self.create_autoregressive_datasets(dataset="train") # ds = self.create_autoregressive_datasets(dataset="train")
# self.train_dataset = ds # self.train_dataset = ds
print(type(self.train_data[0])) # print(type(self.train_data[0]))
ds = [i for data in self.train_data for i in data] ds = [i for data in self.train_data for i in data]
print(type(ds[0])) # print(type(ds[0]))
return DataLoader( return DataLoader(
ds, ds,
batch_size=self.batch_size, batch_size=self.batch_size,
@@ -202,7 +200,7 @@ class GraphDataModule(LightningDataModule):
ds = [i for data in self.val_data for i in data] ds = [i for data in self.val_data for i in data]
return DataLoader( return DataLoader(
ds, ds,
batch_size=self.batch_size, batch_size=128,
shuffle=False, shuffle=False,
num_workers=8, num_workers=8,
pin_memory=True, pin_memory=True,

View File

@@ -0,0 +1,100 @@
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
class DiffusionLayer(MessagePassing):
"""
Modella: T_new = T_old + dt * Divergenza(Flusso)
"""
def __init__(
self,
channels: int,
**kwargs,
):
super().__init__(aggr='add', **kwargs)
self.dt = nn.Parameter(torch.tensor(1e-4))
self.conductivity_net = nn.Sequential(
nn.Linear(channels, channels, bias=False),
nn.GELU(),
nn.Linear(channels, channels, bias=False),
)
self.phys_encoder = nn.Sequential(
nn.Linear(1, 8, bias=False),
nn.Tanh(),
nn.Linear(8, 1, bias=False),
nn.Softplus()
)
def forward(self, x, edge_index, edge_weight):
edge_weight = edge_weight.unsqueeze(-1)
conductance = self.phys_encoder(edge_weight)
net_flux = self.propagate(edge_index, x=x, conductance=conductance)
return x + (net_flux * self.dt)
def message(self, x_i, x_j, conductance):
delta = x_j - x_i
flux = delta * conductance
flux = flux + self.conductivity_net(flux)
return flux
class DiffusionNet(nn.Module):
def __init__(self, input_dim=1, output_dim=1, hidden_dim=8, n_layers=4):
super().__init__()
# Encoder: Projects input temperature to hidden feature space
self.enc = nn.Sequential(
nn.Linear(input_dim, hidden_dim, bias=True),
nn.GELU(),
nn.Linear(hidden_dim, hidden_dim, bias=True),
nn.GELU(),
)
self.scale_x = nn.Parameter(torch.zeros(hidden_dim))
# Scale parameters for conditioning
self.scale_edge_attr = nn.Parameter(torch.zeros(1))
# Stack of Diffusion Layers
self.layers = torch.nn.ModuleList(
[DiffusionLayer(hidden_dim) for _ in range(n_layers)]
)
# Decoder: Projects hidden features back to Temperature space
self.dec = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim, bias=True),
nn.GELU(),
nn.Linear(hidden_dim, output_dim, bias=True),
nn.Softplus(), # Ensure positive temperature output
)
self.func = torch.nn.GELU()
def forward(self, x, edge_index, edge_attr):
# 1. Global Residual Connection setup
# We save the input to add it back at the very end.
# The network learns the correction (Delta T), not the absolute T.
x_input = x
# 2. Encode
h = self.enc(x) * torch.exp(self.scale_x)
# Scale edge attributes (learnable gating of physical conductivity)
w = edge_attr * torch.exp(self.scale_edge_attr)
# 4. Message Passing (Diffusion Steps)
for layer in self.layers:
# h is updated internally via residual connection in DiffusionLayer
h = layer(h, edge_index, w)
h = self.func(h)
# 5. Decode
delta_x = self.dec(h)
# 6. Final Update (Explicit Euler Step)
# T_new = T_old + Correction
# return x_input + delta_x
return delta_ddx

View File

@@ -2,68 +2,209 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch_geometric.nn import MessagePassing from torch_geometric.nn import MessagePassing
from torch.nn.utils import spectral_norm from torch.nn.utils import spectral_norm
from torch_geometric.nn.conv import GCNConv from torch_geometric.nn.conv import GCNConv, SAGEConv, GatedGraphConv, GraphConv
class GCNConvLayer(MessagePassing): # class GCNConvLayer(MessagePassing):
def __init__(self, in_channels, out_channels): # def __init__(
super().__init__(aggr="add") # self,
self.lin_l = nn.Linear(in_channels, out_channels, bias=True) # in_channels,
# self.lin_r = spectral_norm(nn.Linear(in_channels, out_channels, bias=False)) # out_channels,
# aggr: str = 'mean',
# bias: bool = True,
# **kwargs,
# ):
# super().__init__(aggr=aggr, **kwargs)
def forward(self, x, edge_index, edge_attr, deg): # self.in_channels = in_channels
out = self.propagate(edge_index, x=x, edge_attr=edge_attr, deg=deg) # self.out_channels = out_channels
out = self.lin_l(out)
return out
def message(self, x_j, edge_attr): # if isinstance(in_channels, int):
return x_j * edge_attr.view(-1, 1) # in_channels = (in_channels, in_channels)
def aggregate(self, inputs, index, deg): # self.lin_rel = nn.Linear(in_channels[0], out_channels, bias=bias)
# self.lin_root = nn.Linear(in_channels[1], out_channels, bias=False)
# self.reset_parameters()
# def reset_parameters(self):
# super().reset_parameters()
# self.lin_rel.reset_parameters()
# self.lin_root.reset_parameters()
# def forward(self, x, edge_index,
# edge_weight = None, size = None):
# edge_weight = self.normalize(edge_weight, edge_index, x.size(0), dtype=x.dtype)
# out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
# size=size)
# out = self.lin_rel(out)
# out = out + self.lin_root(x)
# return out
# def message(self, x_j, edge_weight):
# return x_j * edge_weight.view(-1, 1)
# @staticmethod
# def normalize(edge_weights, edge_index, num_nodes, dtype=None):
# """Symmetrically normalize edge weights."""
# if dtype is None:
# dtype = edge_weights.dtype
# device = edge_index.device
# row, col = edge_index
# deg = torch.zeros(num_nodes, device=device, dtype=dtype)
# deg = deg.scatter_add(0, row, edge_weights)
# deg_inv_sqrt = deg.pow(-0.5)
# deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0
# return deg_inv_sqrt[row] * edge_weights * deg_inv_sqrt[col]
# class CorrectionNet(nn.Module):
# def __init__(self, input_dim=1, output_dim=1, hidden_dim=8, n_layers=8):
# super().__init__()
# self.enc = nn.Linear(input_dim, hidden_dim, bias=True),
# self.scale_x = nn.Parameter(torch.zeros(hidden_dim))
# self.scale_edge_attr = nn.Parameter(torch.zeros(1))
# self.layers = torch.nn.ModuleList(
# [GCNConv(hidden_dim, hidden_dim, aggr="mean") for _ in range(n_layers)]
# )
# self.dec = nn.Linear(hidden_dim, output_dim, bias=True),
# self.func = torch.nn.GELU()
# def forward(self, x, edge_index, edge_attr,):
# h = self.enc(x) # * torch.exp(self.scale_x)
# edge_attr = edge_attr # * torch.exp(self.scale_edge_attr)
# h = self.func(h)
# for l in self.layers:
# h = l(h, edge_index, edge_attr)
# h = self.func(h)
# out = self.dec(h)
# return out
# class MLPNet(nn.Module):
# def __init__(self, input_dim=1, output_dim=1, hidden_dim=8, n_layers=1):
# super().__init__()
# layers = []
# func = torch.nn.ReLU
# self.network = nn.Sequential(
# nn.Linear(input_dim, hidden_dim),
# func(),
# nn.Linear(hidden_dim, hidden_dim),
# func(),
# nn.Linear(hidden_dim, hidden_dim),
# func(),
# nn.Linear(hidden_dim, output_dim),
# )
# def forward(self, x, edge_index=None, edge_attr=None):
# return self.network(x)
# import torch
# import torch.nn as nn
# from torch_geometric.nn import MessagePassing
# import torch
# import torch.nn as nn
# from torch_geometric.nn import MessagePassing
class DiffusionLayer(MessagePassing):
""" """
TODO: add docstring. Modella: T_new = T_old + dt * Divergenza(Flusso)
""" """
out = super().aggregate(inputs, index) def __init__(
deg = deg + 1e-7 self,
return out / deg.view(-1, 1) channels: int,
**kwargs,
):
super().__init__(aggr='add', **kwargs)
self.dt = nn.Parameter(torch.tensor(1e-4))
self.conductivity_net = nn.Sequential(
nn.Linear(channels, channels, bias=False),
nn.GELU(),
nn.Linear(channels, channels, bias=False),
)
self.phys_encoder = nn.Sequential(
nn.Linear(1, 8, bias=False),
nn.Tanh(),
nn.Linear(8, 1, bias=False),
nn.Softplus()
)
def forward(self, x, edge_index, edge_weight):
edge_weight = edge_weight.unsqueeze(-1)
conductance = self.phys_encoder(edge_weight)
net_flux = self.propagate(edge_index, x=x, conductance=conductance)
return x + (net_flux * self.dt)
def message(self, x_i, x_j, conductance):
delta = x_j - x_i
flux = delta * conductance
flux = flux + self.conductivity_net(flux)
return flux
class CorrectionNet(nn.Module): class CorrectionNet(nn.Module):
def __init__(self, input_dim=1, output_dim=1, hidden_dim=8, n_layers=8): def __init__(self, input_dim=1, output_dim=1, hidden_dim=32, n_layers=4):
super().__init__() super().__init__()
self.enc = nn.Linear(input_dim, hidden_dim, bias=False)
# self.layers = n_layers # Encoder: Projects input temperature to hidden feature space
# self.l = GCNConv(hidden_dim, hidden_dim, aggr="mean") self.enc = nn.Sequential(
nn.Linear(input_dim, hidden_dim, bias=True),
nn.GELU(),
nn.Linear(hidden_dim, hidden_dim, bias=True),
nn.GELU(),
)
self.scale_x = nn.Parameter(torch.zeros(hidden_dim))
# Scale parameters for conditioning
self.scale_edge_attr = nn.Parameter(torch.zeros(1))
# Stack of Diffusion Layers
self.layers = torch.nn.ModuleList( self.layers = torch.nn.ModuleList(
[GCNConv(hidden_dim, hidden_dim, aggr="mean", bias=False) for _ in range(n_layers)] [DiffusionLayer(hidden_dim) for _ in range(n_layers)]
)
self.dec = nn.Linear(hidden_dim, output_dim)
def forward(self, x, edge_index, edge_attr,):
h = self.enc(x)
# h = self.relu(h)
for l in self.layers:
# print(f"Forward pass layer {_}")
h = l(h, edge_index, edge_attr)
# h = self.relu(h)
out = self.dec(h)
return out
class MLPNet(nn.Module):
def __init__(self, input_dim=1, output_dim=1, hidden_dim=8, n_layers=1):
super().__init__()
layers = []
func = torch.nn.ReLU
self.network = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
func(),
nn.Linear(hidden_dim, hidden_dim),
func(),
nn.Linear(hidden_dim, hidden_dim),
func(),
nn.Linear(hidden_dim, output_dim),
) )
def forward(self, x, edge_index=None, edge_attr=None): # Decoder: Projects hidden features back to Temperature space
return self.network(x) self.dec = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim, bias=True),
nn.GELU(),
nn.Linear(hidden_dim, output_dim, bias=True),
nn.Softplus(), # Ensure positive temperature output
)
self.func = torch.nn.GELU()
def forward(self, x, edge_index, edge_attr):
# 1. Global Residual Connection setup
# We save the input to add it back at the very end.
# The network learns the correction (Delta T), not the absolute T.
x_input = x
# 2. Encode
h = self.enc(x) * torch.exp(self.scale_x)
# Scale edge attributes (learnable gating of physical conductivity)
w = edge_attr * torch.exp(self.scale_edge_attr)
# 4. Message Passing (Diffusion Steps)
for layer in self.layers:
# h is updated internally via residual connection in DiffusionLayer
h = layer(h, edge_index, w)
h = self.func(h)
# 5. Decode
delta_x = self.dec(h)
# 6. Final Update (Explicit Euler Step)
# T_new = T_old + Correction
# return x_input + delta_x
return delta_x