fix model
This commit is contained in:
@@ -7,6 +7,7 @@ from matplotlib.tri import Triangulation
|
|||||||
from .model.finite_difference import FiniteDifferenceStep
|
from .model.finite_difference import FiniteDifferenceStep
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
def import_class(class_path: str):
|
def import_class(class_path: str):
|
||||||
module_path, class_name = class_path.rsplit(".", 1) # split last dot
|
module_path, class_name = class_path.rsplit(".", 1) # split last dot
|
||||||
module = importlib.import_module(module_path) # import the module
|
module = importlib.import_module(module_path) # import the module
|
||||||
@@ -14,7 +15,7 @@ def import_class(class_path: str):
|
|||||||
return cls
|
return cls
|
||||||
|
|
||||||
|
|
||||||
def _plot_mesh(pos_, y_, y_pred_, y_true_ ,batch, i, batch_idx):
|
def _plot_mesh(pos_, y_, y_pred_, y_true_, batch, i, batch_idx):
|
||||||
for j in [0, 10, 20, 30]:
|
for j in [0, 10, 20, 30]:
|
||||||
idx = (batch == j).nonzero(as_tuple=True)[0]
|
idx = (batch == j).nonzero(as_tuple=True)[0]
|
||||||
y = y_[idx].detach().cpu()
|
y = y_[idx].detach().cpu()
|
||||||
@@ -49,6 +50,7 @@ def _plot_mesh(pos_, y_, y_pred_, y_true_ ,batch, i, batch_idx):
|
|||||||
plt.savefig(name, dpi=72)
|
plt.savefig(name, dpi=72)
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
def _plot_losses(losses, batch_idx):
|
def _plot_losses(losses, batch_idx):
|
||||||
folder = f"{batch_idx:02d}_images"
|
folder = f"{batch_idx:02d}_images"
|
||||||
plt.figure()
|
plt.figure()
|
||||||
@@ -74,8 +76,8 @@ class GraphSolver(LightningModule):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.model = import_class(model_class_path)(**model_init_args)
|
self.model = import_class(model_class_path)(**model_init_args)
|
||||||
# for param in self.model.parameters():
|
# for param in self.model.parameters():
|
||||||
# print(f"Param: {param.shape}, Grad: {param.grad}")
|
# print(f"Param: {param.shape}, Grad: {param.grad}")
|
||||||
# print(f"Param: {param[0]}")
|
# print(f"Param: {param[0]}")
|
||||||
self.loss = loss if loss is not None else torch.nn.MSELoss()
|
self.loss = loss if loss is not None else torch.nn.MSELoss()
|
||||||
self.unrolling_steps = unrolling_steps
|
self.unrolling_steps = unrolling_steps
|
||||||
|
|
||||||
@@ -101,29 +103,36 @@ class GraphSolver(LightningModule):
|
|||||||
return (0.5 * (c[edge_index[0]] + c[edge_index[1]])).squeeze()
|
return (0.5 * (c[edge_index[0]] + c[edge_index[1]])).squeeze()
|
||||||
|
|
||||||
def _compute_model_steps(
|
def _compute_model_steps(
|
||||||
self, x, edge_index, edge_attr, boundary_mask, boundary_values
|
self,
|
||||||
):
|
x,
|
||||||
out = self.model(x, edge_index, edge_attr)
|
edge_index,
|
||||||
|
edge_attr,
|
||||||
|
boundary_mask,
|
||||||
|
boundary_values,
|
||||||
|
conductivity,
|
||||||
|
):
|
||||||
|
out = self.model(x, edge_index, edge_attr, conductivity)
|
||||||
out[boundary_mask] = boundary_values.unsqueeze(-1)
|
out[boundary_mask] = boundary_values.unsqueeze(-1)
|
||||||
# print(torch.min(out), torch.max(out))
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def _preprocess_batch(self, batch: Batch):
|
def _preprocess_batch(self, batch: Batch):
|
||||||
x, y, c, edge_index, edge_attr = (
|
x, y, c, edge_index, edge_attr, nodal_area = (
|
||||||
batch.x,
|
batch.x,
|
||||||
batch.y,
|
batch.y,
|
||||||
batch.c,
|
batch.c,
|
||||||
batch.edge_index,
|
batch.edge_index,
|
||||||
batch.edge_attr,
|
batch.edge_attr,
|
||||||
|
batch.nodal_area,
|
||||||
)
|
)
|
||||||
edge_attr = 1 / edge_attr
|
edge_attr = 1 / edge_attr
|
||||||
c_ij = self._compute_c_ij(c, edge_index)
|
conductivity = self._compute_c_ij(c, edge_index)
|
||||||
edge_attr = edge_attr * c_ij
|
edge_attr = edge_attr * conductivity
|
||||||
# edge_attr = edge_attr / torch.max(edge_attr)
|
return x, y, edge_index, edge_attr, conductivity
|
||||||
return x, y, edge_index, edge_attr
|
|
||||||
|
|
||||||
def training_step(self, batch: Batch):
|
def training_step(self, batch: Batch):
|
||||||
x, y, edge_index, edge_attr = self._preprocess_batch(batch)
|
x, y, edge_index, edge_attr, conductivity = self._preprocess_batch(
|
||||||
|
batch
|
||||||
|
)
|
||||||
# deg = self._compute_deg(edge_index, edge_attr, x.size(0))
|
# deg = self._compute_deg(edge_index, edge_attr, x.size(0))
|
||||||
losses = []
|
losses = []
|
||||||
# print(x.shape, y.shape)
|
# print(x.shape, y.shape)
|
||||||
@@ -160,12 +169,13 @@ class GraphSolver(LightningModule):
|
|||||||
# deg,
|
# deg,
|
||||||
batch.boundary_mask,
|
batch.boundary_mask,
|
||||||
batch.boundary_values,
|
batch.boundary_values,
|
||||||
|
conductivity,
|
||||||
)
|
)
|
||||||
x = out
|
x = out
|
||||||
# print(out.shape, y[:, i, :].shape)
|
# print(out.shape, y[:, i, :].shape)
|
||||||
losses.append(self.loss(out.flatten(), y[:, i, :].flatten()))
|
losses.append(self.loss(out.flatten(), y[:, i, :].flatten()))
|
||||||
# print(self.model.scale_edge_attr.item())
|
# print(self.model.scale_edge_attr.item())
|
||||||
|
|
||||||
loss = torch.stack(losses).mean()
|
loss = torch.stack(losses).mean()
|
||||||
# for param in self.model.parameters():
|
# for param in self.model.parameters():
|
||||||
# print(f"Param: {param.shape}, Grad: {param.grad}")
|
# print(f"Param: {param.shape}, Grad: {param.grad}")
|
||||||
@@ -173,26 +183,40 @@ class GraphSolver(LightningModule):
|
|||||||
self._log_loss(loss, batch, "train")
|
self._log_loss(loss, batch, "train")
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
def validation_step(self, batch: Batch, batch_idx):
|
def validation_step(self, batch: Batch, batch_idx):
|
||||||
x, y, edge_index, edge_attr = self._preprocess_batch(batch)
|
x, y, edge_index, edge_attr, conductivity = self._preprocess_batch(
|
||||||
|
batch
|
||||||
|
)
|
||||||
# deg = self._compute_deg(edge_index, edge_attr, x.size(0))
|
# deg = self._compute_deg(edge_index, edge_attr, x.size(0))
|
||||||
losses = []
|
losses = []
|
||||||
pos = batch.pos
|
pos = batch.pos
|
||||||
for i in range(self.unrolling_steps):
|
for i in range(self.unrolling_steps):
|
||||||
out = self._compute_model_steps(
|
out = self._compute_model_steps(
|
||||||
# torch.cat([x,pos], dim=-1),
|
# torch.cat([x,pos], dim=-1),
|
||||||
x,
|
x,
|
||||||
edge_index,
|
edge_index,
|
||||||
edge_attr,
|
edge_attr,
|
||||||
# deg,
|
# deg,
|
||||||
batch.boundary_mask,
|
batch.boundary_mask,
|
||||||
batch.boundary_values,
|
batch.boundary_values,
|
||||||
|
conductivity,
|
||||||
)
|
)
|
||||||
if (batch_idx == 0 and self.current_epoch % 10 == 0 and self.current_epoch > 20):
|
if (
|
||||||
_plot_mesh(batch.pos, x, out, y[:, i, :], batch.batch, i, self.current_epoch)
|
batch_idx == 0
|
||||||
|
and self.current_epoch % 10 == 0
|
||||||
|
and self.current_epoch > 0
|
||||||
|
):
|
||||||
|
_plot_mesh(
|
||||||
|
batch.pos,
|
||||||
|
x,
|
||||||
|
out,
|
||||||
|
y[:, i, :],
|
||||||
|
batch.batch,
|
||||||
|
i,
|
||||||
|
self.current_epoch,
|
||||||
|
)
|
||||||
x = out
|
x = out
|
||||||
losses.append(self.loss(out , y[:, i, :]))
|
losses.append(self.loss(out, y[:, i, :]))
|
||||||
|
|
||||||
loss = torch.stack(losses).mean()
|
loss = torch.stack(losses).mean()
|
||||||
self._log_loss(loss, batch, "val")
|
self._log_loss(loss, batch, "val")
|
||||||
@@ -202,5 +226,5 @@ class GraphSolver(LightningModule):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
optimizer = torch.optim.AdamW(self.parameters(), lr=5e-3)
|
optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|||||||
@@ -6,7 +6,39 @@ from torch_geometric.data import Data
|
|||||||
from torch_geometric.loader import DataLoader
|
from torch_geometric.loader import DataLoader
|
||||||
from torch_geometric.utils import to_undirected
|
from torch_geometric.utils import to_undirected
|
||||||
from .mesh_data import MeshData
|
from .mesh_data import MeshData
|
||||||
|
|
||||||
# from torch.utils.data import Dataset
|
# from torch.utils.data import Dataset
|
||||||
|
from torch_geometric.utils import scatter
|
||||||
|
|
||||||
|
|
||||||
|
def compute_nodal_area(edge_index, edge_attr, num_nodes):
|
||||||
|
"""
|
||||||
|
1. Calculates Area ~ (Min Edge Length)^2
|
||||||
|
2. Scales by Mean so average cell has size 1.0
|
||||||
|
"""
|
||||||
|
row, col = edge_index
|
||||||
|
dist = edge_attr.squeeze()
|
||||||
|
|
||||||
|
# 1. Get 'h' (Closest neighbor distance)
|
||||||
|
# Using 'min' filters out diagonal connections in the quad mesh
|
||||||
|
h = scatter(dist, col, dim=0, dim_size=num_nodes, reduce="min")
|
||||||
|
|
||||||
|
# 2. Estimate Raw Area
|
||||||
|
raw_area = h.pow(2)
|
||||||
|
|
||||||
|
# 3. Mean Scaling (The Best Normalization)
|
||||||
|
# This keeps values near 1.0, preserving stability AND physics ratios.
|
||||||
|
# We detach to ensure no gradients flow here (it's static data).
|
||||||
|
mean_val = raw_area.mean().detach()
|
||||||
|
|
||||||
|
# Result:
|
||||||
|
# Small cells -> approx 0.1
|
||||||
|
# Large cells -> approx 5.0
|
||||||
|
# Average -> 1.0
|
||||||
|
# nodal_area = (raw_area / mean_val).unsqueeze(-1) + 1e-6
|
||||||
|
nodal_area = raw_area
|
||||||
|
return nodal_area.unsqueeze(-1)
|
||||||
|
|
||||||
|
|
||||||
class GraphDataModule(LightningDataModule):
|
class GraphDataModule(LightningDataModule):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -26,7 +58,11 @@ class GraphDataModule(LightningDataModule):
|
|||||||
self.hf_repo = hf_repo
|
self.hf_repo = hf_repo
|
||||||
self.split_name = split_name
|
self.split_name = split_name
|
||||||
self.dataset_dict = {}
|
self.dataset_dict = {}
|
||||||
self.train_dataset, self.val_dataset, self.test_dataset = None, None, None
|
self.train_dataset, self.val_dataset, self.test_dataset = (
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
self.unrolling_steps = start_unrolling_steps
|
self.unrolling_steps = start_unrolling_steps
|
||||||
self.geometry_dict = {}
|
self.geometry_dict = {}
|
||||||
self.train_size = train_size
|
self.train_size = train_size
|
||||||
@@ -85,7 +121,9 @@ class GraphDataModule(LightningDataModule):
|
|||||||
conductivity = torch.tensor(
|
conductivity = torch.tensor(
|
||||||
geometry["conductivity"], dtype=torch.float32
|
geometry["conductivity"], dtype=torch.float32
|
||||||
)
|
)
|
||||||
temperatures = torch.tensor(snapshot["temperatures"], dtype=torch.float32)[:40]
|
temperatures = torch.tensor(
|
||||||
|
snapshot["temperatures"], dtype=torch.float32
|
||||||
|
)[:40]
|
||||||
times = torch.tensor(snapshot["times"], dtype=torch.float32)
|
times = torch.tensor(snapshot["times"], dtype=torch.float32)
|
||||||
|
|
||||||
pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2]
|
pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2]
|
||||||
@@ -100,16 +138,19 @@ class GraphDataModule(LightningDataModule):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.build_radial_graph:
|
if self.build_radial_graph:
|
||||||
from pina.graph import RadiusGraph
|
# from pina.graph import RadiusGraph
|
||||||
|
|
||||||
if self.radius is None:
|
# if self.radius is None:
|
||||||
raise ValueError("Radius must be specified for radial graph.")
|
# raise ValueError("Radius must be specified for radial graph.")
|
||||||
edge_index = RadiusGraph.compute_radius_graph(
|
# edge_index = RadiusGraph.compute_radius_graph(
|
||||||
pos, radius=self.radius
|
# pos, radius=self.radius
|
||||||
|
# )
|
||||||
|
# from torch_geometric.utils import remove_self_loops
|
||||||
|
|
||||||
|
# edge_index, _ = remove_self_loops(edge_index)
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Radial graph building not implemented yet."
|
||||||
)
|
)
|
||||||
from torch_geometric.utils import remove_self_loops
|
|
||||||
|
|
||||||
edge_index, _ = remove_self_loops(edge_index)
|
|
||||||
else:
|
else:
|
||||||
edge_index = torch.tensor(
|
edge_index = torch.tensor(
|
||||||
geometry["edge_index"], dtype=torch.int64
|
geometry["edge_index"], dtype=torch.int64
|
||||||
@@ -117,31 +158,37 @@ class GraphDataModule(LightningDataModule):
|
|||||||
edge_index = to_undirected(edge_index, num_nodes=pos.size(0))
|
edge_index = to_undirected(edge_index, num_nodes=pos.size(0))
|
||||||
|
|
||||||
boundary_mask, boundary_values = self._compute_boundary_mask(
|
boundary_mask, boundary_values = self._compute_boundary_mask(
|
||||||
bottom_ids, right_ids, top_ids, left_ids, temperatures[0,:]
|
bottom_ids, right_ids, top_ids, left_ids, temperatures[0, :]
|
||||||
)
|
)
|
||||||
|
edge_attr = torch.norm(pos[edge_index[0]] - pos[edge_index[1]], dim=1)
|
||||||
|
nodal_area = compute_nodal_area(edge_index, edge_attr, pos.size(0))
|
||||||
if self.remove_boundary_edges:
|
if self.remove_boundary_edges:
|
||||||
boundary_idx = torch.unique(boundary_mask)
|
boundary_idx = torch.unique(boundary_mask)
|
||||||
edge_index_mask = ~torch.isin(edge_index[1], boundary_idx)
|
edge_index_mask = ~torch.isin(edge_index[1], boundary_idx)
|
||||||
edge_index = edge_index[:, edge_index_mask]
|
edge_index = edge_index[:, edge_index_mask]
|
||||||
|
edge_attr = edge_attr[edge_index_mask]
|
||||||
edge_attr = torch.norm(pos[edge_index[0]] - pos[edge_index[1]], dim=1)
|
|
||||||
|
|
||||||
n_data = temperatures.size(0) - self.unrolling_steps
|
n_data = temperatures.size(0) - self.unrolling_steps
|
||||||
data = []
|
data = []
|
||||||
for i in range(n_data):
|
for i in range(n_data):
|
||||||
x = temperatures[i, :].unsqueeze(-1)
|
x = temperatures[i, :].unsqueeze(-1)
|
||||||
y = temperatures[i + 1 : i + 1 + self.unrolling_steps, :].unsqueeze(-1).permute(1,0,2)
|
y = (
|
||||||
data.append(MeshData(
|
temperatures[i + 1 : i + 1 + self.unrolling_steps, :]
|
||||||
x=x,
|
.unsqueeze(-1)
|
||||||
y=y,
|
.permute(1, 0, 2)
|
||||||
c=conductivity.unsqueeze(-1),
|
)
|
||||||
edge_index=edge_index,
|
data.append(
|
||||||
pos=pos,
|
MeshData(
|
||||||
edge_attr=edge_attr,
|
x=x,
|
||||||
boundary_mask=boundary_mask,
|
y=y,
|
||||||
boundary_values=boundary_values,
|
c=conductivity.unsqueeze(-1),
|
||||||
))
|
edge_index=edge_index,
|
||||||
|
pos=pos,
|
||||||
|
edge_attr=edge_attr,
|
||||||
|
boundary_mask=boundary_mask,
|
||||||
|
boundary_values=boundary_values,
|
||||||
|
nodal_area=nodal_area,
|
||||||
|
)
|
||||||
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def setup(self, stage: str = None):
|
def setup(self, stage: str = None):
|
||||||
@@ -207,7 +254,9 @@ class GraphDataModule(LightningDataModule):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def test_dataloader(self):
|
def test_dataloader(self):
|
||||||
ds = self.create_autoregressive_datasets(dataset="test", no_unrolling=True)
|
ds = self.create_autoregressive_datasets(
|
||||||
|
dataset="test", no_unrolling=True
|
||||||
|
)
|
||||||
return DataLoader(
|
return DataLoader(
|
||||||
ds,
|
ds,
|
||||||
batch_size=self.batch_size,
|
batch_size=self.batch_size,
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from matplotlib.tri import Triangulation
|
|||||||
from .model.finite_difference import FiniteDifferenceStep
|
from .model.finite_difference import FiniteDifferenceStep
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
def import_class(class_path: str):
|
def import_class(class_path: str):
|
||||||
module_path, class_name = class_path.rsplit(".", 1) # split last dot
|
module_path, class_name = class_path.rsplit(".", 1) # split last dot
|
||||||
module = importlib.import_module(module_path) # import the module
|
module = importlib.import_module(module_path) # import the module
|
||||||
@@ -43,6 +44,7 @@ def _plot_mesh(pos, y, y_pred, batch, i, batch_idx):
|
|||||||
plt.savefig(name, dpi=72)
|
plt.savefig(name, dpi=72)
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
def _plot_losses(losses, batch_idx):
|
def _plot_losses(losses, batch_idx):
|
||||||
folder = f"{batch_idx:02d}_images"
|
folder = f"{batch_idx:02d}_images"
|
||||||
plt.figure()
|
plt.figure()
|
||||||
|
|||||||
@@ -2,37 +2,39 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch_geometric.nn import MessagePassing
|
from torch_geometric.nn import MessagePassing
|
||||||
|
|
||||||
|
|
||||||
class DiffusionLayer(MessagePassing):
|
class DiffusionLayer(MessagePassing):
|
||||||
"""
|
"""
|
||||||
Modella: T_new = T_old + dt * Divergenza(Flusso)
|
Modella: T_new = T_old + dt * Divergenza(Flusso)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
channels: int,
|
channels: int,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|
||||||
super().__init__(aggr='add', **kwargs)
|
super().__init__(aggr="add", **kwargs)
|
||||||
|
|
||||||
self.dt = nn.Parameter(torch.tensor(1e-4))
|
self.dt = nn.Parameter(torch.tensor(1e-4))
|
||||||
self.conductivity_net = nn.Sequential(
|
self.conductivity_net = nn.Sequential(
|
||||||
nn.Linear(channels, channels, bias=False),
|
nn.Linear(channels, channels, bias=False),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.Linear(channels, channels, bias=False),
|
nn.Linear(channels, channels, bias=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.phys_encoder = nn.Sequential(
|
self.phys_encoder = nn.Sequential(
|
||||||
nn.Linear(1, 8, bias=False),
|
nn.Linear(1, 8, bias=False),
|
||||||
nn.Tanh(),
|
nn.Tanh(),
|
||||||
nn.Linear(8, 1, bias=False),
|
nn.Linear(8, 1, bias=False),
|
||||||
nn.Softplus()
|
nn.Softplus(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, edge_index, edge_weight):
|
def forward(self, x, edge_index, edge_weight, conductivity):
|
||||||
edge_weight = edge_weight.unsqueeze(-1)
|
edge_weight = edge_weight.unsqueeze(-1)
|
||||||
conductance = self.phys_encoder(edge_weight)
|
conductance = self.phys_encoder(edge_weight)
|
||||||
net_flux = self.propagate(edge_index, x=x, conductance=conductance)
|
net_flux = self.propagate(edge_index, x=x, conductance=conductance)
|
||||||
return x + (net_flux * self.dt)
|
return x + ((net_flux) * self.dt)
|
||||||
|
|
||||||
def message(self, x_i, x_j, conductance):
|
def message(self, x_i, x_j, conductance):
|
||||||
delta = x_j - x_i
|
delta = x_j - x_i
|
||||||
@@ -44,7 +46,7 @@ class DiffusionLayer(MessagePassing):
|
|||||||
class DiffusionNet(nn.Module):
|
class DiffusionNet(nn.Module):
|
||||||
def __init__(self, input_dim=1, output_dim=1, hidden_dim=8, n_layers=4):
|
def __init__(self, input_dim=1, output_dim=1, hidden_dim=8, n_layers=4):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Encoder: Projects input temperature to hidden feature space
|
# Encoder: Projects input temperature to hidden feature space
|
||||||
self.enc = nn.Sequential(
|
self.enc = nn.Sequential(
|
||||||
nn.Linear(input_dim, hidden_dim, bias=True),
|
nn.Linear(input_dim, hidden_dim, bias=True),
|
||||||
@@ -57,12 +59,12 @@ class DiffusionNet(nn.Module):
|
|||||||
|
|
||||||
# Scale parameters for conditioning
|
# Scale parameters for conditioning
|
||||||
self.scale_edge_attr = nn.Parameter(torch.zeros(1))
|
self.scale_edge_attr = nn.Parameter(torch.zeros(1))
|
||||||
|
|
||||||
# Stack of Diffusion Layers
|
# Stack of Diffusion Layers
|
||||||
self.layers = torch.nn.ModuleList(
|
self.layers = torch.nn.ModuleList(
|
||||||
[DiffusionLayer(hidden_dim) for _ in range(n_layers)]
|
[DiffusionLayer(hidden_dim) for _ in range(n_layers)]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Decoder: Projects hidden features back to Temperature space
|
# Decoder: Projects hidden features back to Temperature space
|
||||||
self.dec = nn.Sequential(
|
self.dec = nn.Sequential(
|
||||||
nn.Linear(hidden_dim, hidden_dim, bias=True),
|
nn.Linear(hidden_dim, hidden_dim, bias=True),
|
||||||
@@ -73,28 +75,28 @@ class DiffusionNet(nn.Module):
|
|||||||
|
|
||||||
self.func = torch.nn.GELU()
|
self.func = torch.nn.GELU()
|
||||||
|
|
||||||
def forward(self, x, edge_index, edge_attr):
|
def forward(self, x, edge_index, edge_attr, conductivity):
|
||||||
# 1. Global Residual Connection setup
|
# 1. Global Residual Connection setup
|
||||||
# We save the input to add it back at the very end.
|
# We save the input to add it back at the very end.
|
||||||
# The network learns the correction (Delta T), not the absolute T.
|
# The network learns the correction (Delta T), not the absolute T.
|
||||||
x_input = x
|
x_input = x
|
||||||
|
|
||||||
# 2. Encode
|
# 2. Encode
|
||||||
h = self.enc(x) * torch.exp(self.scale_x)
|
h = self.enc(x) * torch.exp(self.scale_x)
|
||||||
|
|
||||||
# Scale edge attributes (learnable gating of physical conductivity)
|
# Scale edge attributes (learnable gating of physical conductivity)
|
||||||
w = edge_attr * torch.exp(self.scale_edge_attr)
|
w = edge_attr * torch.exp(self.scale_edge_attr)
|
||||||
|
|
||||||
# 4. Message Passing (Diffusion Steps)
|
# 4. Message Passing (Diffusion Steps)
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
# h is updated internally via residual connection in DiffusionLayer
|
# h is updated internally via residual connection in DiffusionLayer
|
||||||
h = layer(h, edge_index, w)
|
h = layer(h, edge_index, w, conductivity)
|
||||||
h = self.func(h)
|
h = self.func(h)
|
||||||
|
|
||||||
# 5. Decode
|
# 5. Decode
|
||||||
delta_x = self.dec(h)
|
delta_x = self.dec(h)
|
||||||
|
|
||||||
# 6. Final Update (Explicit Euler Step)
|
# 6. Final Update (Explicit Euler Step)
|
||||||
# T_new = T_old + Correction
|
# T_new = T_old + Correction
|
||||||
# return x_input + delta_x
|
# return x_input + delta_x
|
||||||
return delta_ddx
|
return delta_x
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ from torch_geometric.nn.conv import GCNConv, SAGEConv, GatedGraphConv, GraphConv
|
|||||||
|
|
||||||
# def message(self, x_j, edge_weight):
|
# def message(self, x_j, edge_weight):
|
||||||
# return x_j * edge_weight.view(-1, 1)
|
# return x_j * edge_weight.view(-1, 1)
|
||||||
|
|
||||||
# @staticmethod
|
# @staticmethod
|
||||||
# def normalize(edge_weights, edge_index, num_nodes, dtype=None):
|
# def normalize(edge_weights, edge_index, num_nodes, dtype=None):
|
||||||
# """Symmetrically normalize edge weights."""
|
# """Symmetrically normalize edge weights."""
|
||||||
@@ -58,7 +58,7 @@ from torch_geometric.nn.conv import GCNConv, SAGEConv, GatedGraphConv, GraphConv
|
|||||||
# deg_inv_sqrt = deg.pow(-0.5)
|
# deg_inv_sqrt = deg.pow(-0.5)
|
||||||
# deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0
|
# deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0
|
||||||
# return deg_inv_sqrt[row] * edge_weights * deg_inv_sqrt[col]
|
# return deg_inv_sqrt[row] * edge_weights * deg_inv_sqrt[col]
|
||||||
|
|
||||||
|
|
||||||
# class CorrectionNet(nn.Module):
|
# class CorrectionNet(nn.Module):
|
||||||
# def __init__(self, input_dim=1, output_dim=1, hidden_dim=8, n_layers=8):
|
# def __init__(self, input_dim=1, output_dim=1, hidden_dim=8, n_layers=8):
|
||||||
@@ -89,7 +89,7 @@ from torch_geometric.nn.conv import GCNConv, SAGEConv, GatedGraphConv, GraphConv
|
|||||||
# super().__init__()
|
# super().__init__()
|
||||||
# layers = []
|
# layers = []
|
||||||
# func = torch.nn.ReLU
|
# func = torch.nn.ReLU
|
||||||
|
|
||||||
# self.network = nn.Sequential(
|
# self.network = nn.Sequential(
|
||||||
# nn.Linear(input_dim, hidden_dim),
|
# nn.Linear(input_dim, hidden_dim),
|
||||||
# func(),
|
# func(),
|
||||||
@@ -112,30 +112,32 @@ from torch_geometric.nn.conv import GCNConv, SAGEConv, GatedGraphConv, GraphConv
|
|||||||
# import torch.nn as nn
|
# import torch.nn as nn
|
||||||
# from torch_geometric.nn import MessagePassing
|
# from torch_geometric.nn import MessagePassing
|
||||||
|
|
||||||
|
|
||||||
class DiffusionLayer(MessagePassing):
|
class DiffusionLayer(MessagePassing):
|
||||||
"""
|
"""
|
||||||
Modella: T_new = T_old + dt * Divergenza(Flusso)
|
Modella: T_new = T_old + dt * Divergenza(Flusso)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
channels: int,
|
channels: int,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|
||||||
super().__init__(aggr='add', **kwargs)
|
super().__init__(aggr="add", **kwargs)
|
||||||
|
|
||||||
self.dt = nn.Parameter(torch.tensor(1e-4))
|
self.dt = nn.Parameter(torch.tensor(1e-4))
|
||||||
self.conductivity_net = nn.Sequential(
|
self.conductivity_net = nn.Sequential(
|
||||||
nn.Linear(channels, channels, bias=False),
|
nn.Linear(channels, channels, bias=False),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.Linear(channels, channels, bias=False),
|
nn.Linear(channels, channels, bias=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.phys_encoder = nn.Sequential(
|
self.phys_encoder = nn.Sequential(
|
||||||
nn.Linear(1, 8, bias=False),
|
nn.Linear(1, 8, bias=False),
|
||||||
nn.Tanh(),
|
nn.Tanh(),
|
||||||
nn.Linear(8, 1, bias=False),
|
nn.Linear(8, 1, bias=False),
|
||||||
nn.Softplus()
|
nn.Softplus(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, edge_index, edge_weight):
|
def forward(self, x, edge_index, edge_weight):
|
||||||
@@ -154,7 +156,7 @@ class DiffusionLayer(MessagePassing):
|
|||||||
class CorrectionNet(nn.Module):
|
class CorrectionNet(nn.Module):
|
||||||
def __init__(self, input_dim=1, output_dim=1, hidden_dim=32, n_layers=4):
|
def __init__(self, input_dim=1, output_dim=1, hidden_dim=32, n_layers=4):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Encoder: Projects input temperature to hidden feature space
|
# Encoder: Projects input temperature to hidden feature space
|
||||||
self.enc = nn.Sequential(
|
self.enc = nn.Sequential(
|
||||||
nn.Linear(input_dim, hidden_dim, bias=True),
|
nn.Linear(input_dim, hidden_dim, bias=True),
|
||||||
@@ -167,12 +169,12 @@ class CorrectionNet(nn.Module):
|
|||||||
|
|
||||||
# Scale parameters for conditioning
|
# Scale parameters for conditioning
|
||||||
self.scale_edge_attr = nn.Parameter(torch.zeros(1))
|
self.scale_edge_attr = nn.Parameter(torch.zeros(1))
|
||||||
|
|
||||||
# Stack of Diffusion Layers
|
# Stack of Diffusion Layers
|
||||||
self.layers = torch.nn.ModuleList(
|
self.layers = torch.nn.ModuleList(
|
||||||
[DiffusionLayer(hidden_dim) for _ in range(n_layers)]
|
[DiffusionLayer(hidden_dim) for _ in range(n_layers)]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Decoder: Projects hidden features back to Temperature space
|
# Decoder: Projects hidden features back to Temperature space
|
||||||
self.dec = nn.Sequential(
|
self.dec = nn.Sequential(
|
||||||
nn.Linear(hidden_dim, hidden_dim, bias=True),
|
nn.Linear(hidden_dim, hidden_dim, bias=True),
|
||||||
@@ -185,26 +187,26 @@ class CorrectionNet(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x, edge_index, edge_attr):
|
def forward(self, x, edge_index, edge_attr):
|
||||||
# 1. Global Residual Connection setup
|
# 1. Global Residual Connection setup
|
||||||
# We save the input to add it back at the very end.
|
# We save the input to add it back at the very end.
|
||||||
# The network learns the correction (Delta T), not the absolute T.
|
# The network learns the correction (Delta T), not the absolute T.
|
||||||
x_input = x
|
x_input = x
|
||||||
|
|
||||||
# 2. Encode
|
# 2. Encode
|
||||||
h = self.enc(x) * torch.exp(self.scale_x)
|
h = self.enc(x) * torch.exp(self.scale_x)
|
||||||
|
|
||||||
# Scale edge attributes (learnable gating of physical conductivity)
|
# Scale edge attributes (learnable gating of physical conductivity)
|
||||||
w = edge_attr * torch.exp(self.scale_edge_attr)
|
w = edge_attr * torch.exp(self.scale_edge_attr)
|
||||||
|
|
||||||
# 4. Message Passing (Diffusion Steps)
|
# 4. Message Passing (Diffusion Steps)
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
# h is updated internally via residual connection in DiffusionLayer
|
# h is updated internally via residual connection in DiffusionLayer
|
||||||
h = layer(h, edge_index, w)
|
h = layer(h, edge_index, w)
|
||||||
h = self.func(h)
|
h = self.func(h)
|
||||||
|
|
||||||
# 5. Decode
|
# 5. Decode
|
||||||
delta_x = self.dec(h)
|
delta_x = self.dec(h)
|
||||||
|
|
||||||
# 6. Final Update (Explicit Euler Step)
|
# 6. Final Update (Explicit Euler Step)
|
||||||
# T_new = T_old + Correction
|
# T_new = T_old + Correction
|
||||||
# return x_input + delta_x
|
# return x_input + delta_x
|
||||||
return delta_x
|
return delta_x
|
||||||
|
|||||||
Reference in New Issue
Block a user