implement first GNO

This commit is contained in:
FilippoOlivo
2025-09-25 14:44:39 +02:00
parent d53b076ecc
commit f3be9e99f8
5 changed files with 89 additions and 306 deletions

View File

@@ -64,7 +64,7 @@ class GraphDataModule(LightningDataModule):
edge_attr = torch.cat( edge_attr = torch.cat(
[edge_attr, torch.norm(edge_attr, dim=1).unsqueeze(-1)], dim=1 [edge_attr, torch.norm(edge_attr, dim=1).unsqueeze(-1)], dim=1
) )
return Data( return Data(
x=boundary_vales.unsqueeze(-1), x=boundary_vales.unsqueeze(-1),
c=conductivity.unsqueeze(-1), c=conductivity.unsqueeze(-1),

View File

@@ -0,0 +1,25 @@
from pina.model import GraphNeuralOperator
import torch
from torch_geometric.data import Data
class GNO(torch.nn.Module):
    """Baseline graph neural operator built on PINA's ``GraphNeuralOperator``.

    The node input ``x`` and the conditioning field ``c`` are concatenated
    along the last dimension, lifted to ``hidden`` channels by a single
    linear layer, processed by the wrapped ``GraphNeuralOperator`` and
    projected to ``out_ch`` channels by a second linear layer.

    :param int x_ch_node: Number of channels of the node input ``x``.
    :param int f_ch_node: Number of channels of the conditioning field ``c``.
    :param int hidden: Width of the lifted (hidden) node representation.
    :param int layers: Forwarded as ``n_layers`` to ``GraphNeuralOperator``.
    :param int edge_ch: Forwarded as ``edge_features`` to
        ``GraphNeuralOperator``. Default is ``0``.
    :param int out_ch: Number of output channels per node. Default is ``1``.
    :param int internal_n_layers: Forwarded as ``internal_n_layers`` to
        ``GraphNeuralOperator``. Keyword-only. Default is ``2``.
    :param bool shared_weights: Forwarded as ``shared_weights`` to
        ``GraphNeuralOperator``. Keyword-only. Default is ``False``.
    """

    def __init__(
        self,
        x_ch_node,
        f_ch_node,
        hidden,
        layers,
        edge_ch=0,
        out_ch=1,
        *,
        internal_n_layers=2,
        shared_weights=False,
    ):
        super().__init__()
        # Lifting is created before the projection so that parameter
        # initialisation order (and therefore seeded results) is stable.
        lifting_operator = torch.nn.Linear(x_ch_node + f_ch_node, hidden)
        projection_operator = torch.nn.Linear(hidden, out_ch)
        self.gno = GraphNeuralOperator(
            lifting_operator=lifting_operator,
            projection_operator=projection_operator,
            edge_features=edge_ch,
            n_layers=layers,
            internal_n_layers=internal_n_layers,
            shared_weights=shared_weights,
        )

    def forward(self, x, c, edge_index, edge_attr):
        """Run the operator on one graph (or one batched graph).

        :param torch.Tensor x: Node input; concatenated with ``c`` on the
            last dimension (presumably ``[N, x_ch_node]`` -- confirm with
            the data module).
        :param torch.Tensor c: Conditioning field per node (presumably
            ``[N, f_ch_node]`` -- confirm with the data module).
        :param torch.Tensor edge_index: Graph connectivity, passed through
            unchanged to ``Data``.
        :param torch.Tensor edge_attr: Edge features, passed through
            unchanged to ``Data``.
        :return: Whatever the wrapped ``GraphNeuralOperator`` returns for
            the assembled graph.
        """
        node_features = torch.cat([x, c], dim=-1)
        # GraphNeuralOperator is called with a single graph object, so the
        # separate tensors are packed into a torch_geometric ``Data``.
        graph = Data(
            x=node_features, edge_index=edge_index, edge_attr=edge_attr
        )
        return self.gno(graph)

View File

@@ -8,19 +8,16 @@ class FiLM(nn.Module):
def __init__(self, c_ch, h_ch): def __init__(self, c_ch, h_ch):
super().__init__() super().__init__()
self.net = nn.Sequential( self.net = nn.Sequential(
nn.Linear(c_ch, 2*h_ch), nn.Linear(c_ch, 2 * h_ch), nn.SiLU(), nn.Linear(2 * h_ch, 2 * h_ch)
nn.SiLU(),
nn.Linear(2*h_ch, 2*h_ch)
) )
# init to identity: gamma≈0 (so 1+gamma=1), beta=0 # init to identity: gamma≈0 (so 1+gamma=1), beta=0
nn.init.zeros_(self.net[-1].weight) nn.init.zeros_(self.net[-1].weight)
nn.init.zeros_(self.net[-1].bias) nn.init.zeros_(self.net[-1].bias)
self.norm = nn.LayerNorm(h_ch)
def forward(self, h, c): def forward(self, h, c):
gb = self.net(c) gb = self.net(c)
gamma, beta = gb.chunk(2, dim=-1) gamma, beta = gb.chunk(2, dim=-1)
return (1 + gamma) * self.norm(h) + beta return (1 + gamma) * h + beta
class ConditionalGNOBlock(MessagePassing): class ConditionalGNOBlock(MessagePassing):
@@ -28,46 +25,35 @@ class ConditionalGNOBlock(MessagePassing):
Message passing with FiLM applied to the MESSAGE m_ij, Message passing with FiLM applied to the MESSAGE m_ij,
using edge context c_ij = (c_i + c_j)/2. using edge context c_ij = (c_i + c_j)/2.
""" """
def __init__(self, hidden_ch, edge_ch=0, aggr="mean"):
def __init__(self, hidden_ch, edge_ch=0, aggr="add"):
super().__init__(aggr=aggr, node_dim=0) super().__init__(aggr=aggr, node_dim=0)
self.pre_norm = nn.LayerNorm(hidden_ch)
# raw message builder
self.msg = nn.Sequential(
nn.Linear(2*hidden_ch + edge_ch, 2*hidden_ch),
nn.SiLU(),
nn.Linear(2*hidden_ch, hidden_ch)
)
# FiLM over the message (per-edge) # FiLM over the message (per-edge)
self.film_msg = FiLM(c_ch=hidden_ch, h_ch=hidden_ch) self.film_msg = FiLM(c_ch=hidden_ch, h_ch=hidden_ch)
self.edge_attr_net = nn.Sequential(
# node update with residual nn.Linear(edge_ch, hidden_ch // 2),
self.update_mlp = nn.Sequential(
nn.Linear(2*hidden_ch, hidden_ch),
nn.SiLU(), nn.SiLU(),
nn.Linear(hidden_ch, hidden_ch) nn.Linear(hidden_ch // 2, hidden_ch),
)
self.x_net = nn.Sequential(
nn.Linear(hidden_ch, hidden_ch * 2),
nn.SiLU(),
nn.Linear(hidden_ch * 2, hidden_ch),
) )
def forward(self, x, c, edge_index, edge_attr=None): def forward(self, x, c, edge_index, edge_attr=None):
# pre-norm helps stability return self.propagate(edge_index, x=x, c=c, edge_attr=edge_attr)
x_in = x
x = self.pre_norm(x)
m = self.propagate(edge_index, x=x, c=c, edge_attr=edge_attr)
out = self.update_mlp(torch.cat([x_in, m], dim=-1))
return x_in + out # residual
def message(self, x_i, x_j, c_i, c_j, edge_attr): def update(self, aggr_out, x):
return self.x_net(x) + aggr_out
def message(self, x_j, c_i, c_j, edge_attr):
# c_ij = (c_i + c_j)/2
c_ij = 0.5 * (c_i + c_j)
m = self.film_msg(x_j, c_ij)
if edge_attr is not None: if edge_attr is not None:
m_in = torch.cat([x_i, x_j, edge_attr], dim=-1) a_ij = self.edge_attr_net(edge_attr)
else: m = m * a_ij
m_in = torch.cat([x_i, x_j], dim=-1)
m_raw = self.msg(m_in)
# edge conditioning: simple mean
c_ctx = 0.5 * (c_i + c_j)
m = self.film_msg(m_raw, c_ctx)
return m return m
@@ -79,7 +65,10 @@ class GatingGNO(nn.Module):
Out: Out:
y : [N, out_ch] y : [N, out_ch]
""" """
def __init__(self, x_ch_node, f_ch_node, hidden, layers, edge_ch=0, out_ch=1):
def __init__(
self, x_ch_node, f_ch_node, hidden, layers, edge_ch=0, out_ch=1
):
super().__init__() super().__init__()
self.encoder_x = nn.Sequential( self.encoder_x = nn.Sequential(
nn.Linear(x_ch_node, hidden // 2), nn.Linear(x_ch_node, hidden // 2),
@@ -92,12 +81,15 @@ class GatingGNO(nn.Module):
nn.Linear(hidden // 2, hidden), nn.Linear(hidden // 2, hidden),
) )
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(
[ConditionalGNOBlock(hidden_ch=hidden, edge_ch=edge_ch) for _ in range(layers)] [
ConditionalGNOBlock(hidden_ch=hidden, edge_ch=edge_ch)
for _ in range(layers)
]
) )
self.dec = nn.Sequential( self.dec = nn.Sequential(
nn.LayerNorm(hidden), nn.Linear(hidden, hidden // 2),
nn.SiLU(), nn.SiLU(),
nn.Linear(hidden, out_ch) nn.Linear(hidden // 2, out_ch),
) )
def forward(self, x, c, edge_index, edge_attr=None): def forward(self, x, c, edge_index, edge_attr=None):
@@ -105,4 +97,4 @@ class GatingGNO(nn.Module):
c = self.encoder_c(c) # [N,H] c = self.encoder_c(c) # [N,H]
for blk in self.blocks: for blk in self.blocks:
x = blk(x, c, edge_index, edge_attr=edge_attr) x = blk(x, c, edge_index, edge_attr=edge_attr)
return self.dec(x) return self.dec(x)

View File

@@ -4,7 +4,12 @@ from torch_geometric.data import Batch
class GraphSolver(LightningModule): class GraphSolver(LightningModule):
def __init__(self, model: torch.nn.Module, loss: torch.nn.Module = None, unrolling_steps: int = 10): def __init__(
self,
model: torch.nn.Module,
loss: torch.nn.Module = None,
unrolling_steps: int = 10,
):
super().__init__() super().__init__()
self.model = model self.model = model
self.loss = loss if loss is not None else torch.nn.MSELoss() self.loss = loss if loss is not None else torch.nn.MSELoss()
@@ -18,7 +23,7 @@ class GraphSolver(LightningModule):
edge_attr: torch.Tensor, edge_attr: torch.Tensor,
): ):
return self.model(x, c, edge_index, edge_attr) return self.model(x, c, edge_index, edge_attr)
def _compute_loss_train(self, x, x_prev, y): def _compute_loss_train(self, x, x_prev, y):
return self.loss(x, y) + self.loss(x, x_prev) return self.loss(x, y) + self.loss(x, x_prev)
@@ -27,7 +32,7 @@ class GraphSolver(LightningModule):
def _preprocess_batch(self, batch: Batch): def _preprocess_batch(self, batch: Batch):
return batch.x, batch.y, batch.c, batch.edge_index, batch.edge_attr return batch.x, batch.y, batch.c, batch.edge_index, batch.edge_attr
def _log_loss(self, loss, batch, stage: str): def _log_loss(self, loss, batch, stage: str):
self.log( self.log(
f"{stage}_loss", f"{stage}_loss",
@@ -41,13 +46,14 @@ class GraphSolver(LightningModule):
def training_step(self, batch: Batch, _): def training_step(self, batch: Batch, _):
x, y, c, edge_index, edge_attr = self._preprocess_batch(batch) x, y, c, edge_index, edge_attr = self._preprocess_batch(batch)
loss = 0.0
for _ in range(self.unrolling_steps): for _ in range(self.unrolling_steps):
x_prev = x.detach() x_prev = x.detach()
x = self(x_prev, c, edge_index=edge_index, edge_attr=edge_attr) x = self(x_prev, c, edge_index=edge_index, edge_attr=edge_attr)
loss = self.loss(x, y) loss += self.loss(x, y)
self._log_loss(loss, batch, "train") self._log_loss(loss, batch, "train")
return loss return loss
def validation_step(self, batch: Batch, _): def validation_step(self, batch: Batch, _):
x, y, c, edge_index, edge_attr = self._preprocess_batch(batch) x, y, c, edge_index, edge_attr = self._preprocess_batch(batch)
for _ in range(self.unrolling_steps): for _ in range(self.unrolling_steps):
@@ -70,5 +76,5 @@ class GraphSolver(LightningModule):
return loss return loss
def configure_optimizers(self): def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=5e-3) optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer return optimizer

276
run.py
View File

@@ -1,264 +1,24 @@
import os from lightning import Trainer
import yaml from ThermalSolver.module import GraphSolver
import importlib from ThermalSolver.data_module import GraphDataModule
from pina import Trainer from ThermalSolver.model.local_gno import GatingGNO
import torch from ThermalSolver.model.basic_gno import GNO
import numpy as np
from pina.problem.zoo import SupervisedProblem
from pina.solver import ReducedOrderModelSolver
from pina.solver import SupervisedSolver
from pina.optim import TorchOptimizer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import TensorBoardLogger
from copy import deepcopy
from datasets import load_dataset
from torch.utils.data import random_split
from matplotlib import pyplot as plt
from matplotlib.tri import Triangulation
def compute_error(u_true, u_pred):
"""Compute the L2 error between true and predicted solutions."""
return np.linalg.norm(u_true - u_pred) / np.linalg.norm(u_true)
def argparse():
"""
Parse command line arguments for training configuration.
"""
import argparse
parser = argparse.ArgumentParser(description="Train a model with specified "
"parameters.")
parser.add_argument('--config', type=str, required=True,
help='Path to the configuration YAML file.')
return parser.parse_args()
def load_config(config_file):
"""
Configure the training parameters.
:param str config_file: Path to the configuration file.
:return: Configuration dictionary.
:rtype: dict
"""
with open(config_file, "r") as f:
config = yaml.safe_load(f)
return config
def load_model(model_args):
"""
Load the model class and instantiate it with the provided arguments.
:param dict model_args: Arguments for the model class.
:return: An instance of the model class.
:rtype: torch.nn.Module
"""
model_class = model_args.pop("model_class", "")
module_path, class_name = model_class.rsplit(".", 1)
module = importlib.import_module(module_path)
cls = getattr(module, class_name)
model = cls(**model_args)
return model
def load_data():
"""
Load the dataset from the specified path and preprocess it.
:param dict data_args: Arguments for loading the dataset.
:return: Training and testing datasets, normalizers, and points.
:rtype: tuple
- u_train (torch.Tensor): Training simulations.
- p_train (torch.Tensor): Training parameters.
"""
snapshots = load_dataset("SISSAmathLab/thermal-conduction", name="snapshots")["default"]
geom = load_dataset("SISSAmathLab/thermal-conduction", name="geometry")["default"]
points = torch.tensor(np.array(geom["points"]), dtype=torch.float32)
temperature = torch.tensor(np.array(snapshots["temperature"]), dtype=torch.float32)
params = torch.tensor(np.array(snapshots["params"]), dtype=torch.float32)
train_size = int(0.8 * len(temperature))
temp_train, temp_test = temperature[:train_size], temperature[train_size:]
params_train, params_test = params[:train_size], params[train_size:]
return temp_train, params_train, temp_test, params_test, points
def load_trainer(trainer_args, solver):
"""
Load and configure the Trainer for training.
:param dict trainer_args: Arguments for the Trainer.
:param ~pina.solver.solver_interface.SolverInterface solver: The solver
instance to be used by the Trainer.
:return: Configured Trainer instance.
:rtype: Trainer
"""
patience = trainer_args.pop("patience", 100)
es = EarlyStopping(
monitor='val_loss',
patience=patience,
mode='min',
verbose=True
)
checkpoint = ModelCheckpoint(
monitor='val_loss',
mode='min',
save_top_k=1,
filename='best_model',
save_weights_only=True
)
logger = TensorBoardLogger(
save_dir=trainer_args.pop('log_dir', 'logs'),
name=trainer_args.pop('name'),
version=f"{trainer_args.pop('version'):03d}" if trainer_args.get('version', None) is not None else None
)
trainer_args['callbacks'] = [es, checkpoint]
trainer_args['solver'] = solver
trainer_args['logger'] = logger
trainer = Trainer(**trainer_args)
return trainer
def load_optimizer(optim_args):
"""
Load the optimizer class and instantiate it with the provided arguments.
:param dict optim_args: Arguments for the optimizer class.
:return: An instance of the TorchOptimizer class.
:rtype: TorchOptimizer
"""
print("Loading optimizer with args:", optim_args)
optim_class = optim_args.pop("optimizer_class", "")
module_path, class_name = optim_class.rsplit(".", 1)
module = importlib.import_module(module_path)
cls = getattr(module, class_name)
return TorchOptimizer(
cls,
**optim_args
)
def train(trainer):
"""
Train the model using the provided Trainer instance.
:param Trainer trainer: The Trainer instance configured for training.
"""
trainer.train()
def save_model(solver, trainer, problem, model, int_net):
"""
Save the trained model and its components to disk.
:param ~pina.solver.solver_interface.SolverInterface solver: The solver
instance containing the trained model.
:param Trainer trainer: The Trainer instance used for training.
:param ~pina.problem.zoo.SupervisedProblem problem: The problem instance
associated with the solver.
:param torch.nn.Module model: The trained model to be saved.
:param torch.nn.Module int_net: The interpolation network, if used.
"""
model_path = trainer.logger.log_dir.replace("logs", "models")
os.makedirs(model_path, exist_ok=True)
if int_net is None:
solver = SupervisedSolver.load_from_checkpoint(
os.path.join(trainer.logger.log_dir, 'checkpoints',
'best_model.ckpt'),
problem=problem,
model=model,
use_lt=False)
model = solver.model.cpu()
model.eval()
torch.save(model.state_dict(), os.path.join(model_path, 'model.pth'))
if hasattr(model, 'pod'):
torch.save(model.pod.basis, os.path.join(model_path,
'pod_basis.pth')
)
return model
else:
solver = ReducedOrderModelSolver.load_from_checkpoint(
os.path.join(trainer.logger.log_dir, 'checkpoints',
'best_model.ckpt'),
problem=problem,
interpolation_network=int_net,
reduction_network=model
)
int_net = solver.model["interpolation_network"].cpu()
torch.save(int_net.state_dict(), os.path.join(model_path,
'interpolation_network.pth')
)
model = solver.model["reduction_network"].cpu()
torch.save(model.state_dict(), os.path.join(model_path,
'reduction_network.pth')
)
if hasattr(model, 'pod'):
torch.save(model.pod.basis, os.path.join(model_path,
'pod_basis.pth')
)
def main(): def main():
seed_everything(1999, workers=True) trainer = Trainer(max_epochs=100, accelerator="cuda", devices=1)
args = argparse() data_module = GraphDataModule(
config = load_config(args.config) hf_repo="SISSAmathLab/thermal-conduction",
config_ = deepcopy(config) split_name="easy",
model_args = config.get("model", {}) train_size=0.8,
model = load_model(model_args) val_size=0.1,
test_size=0.1,
if "interpolation" in config: batch_size=8,
model_args = config["interpolation"] )
int_net = load_model(model_args) model = GatingGNO(x_ch_node=1, f_ch_node=1, hidden=16, layers=8, edge_ch=3, out_ch=1)
else: solver = GraphSolver(model)
int_net = None trainer.fit(solver, datamodule=data_module)
print("Done!")
temperature, params, _, _, points = load_data()
problem = SupervisedProblem(output_=temperature, input_=params)
optimizer = load_optimizer(config.get("optimizer", {}))
if int_net is None:
if hasattr(model, 'fit_pod'):
model.fit_pod(temperature)
solver = SupervisedSolver(
problem=problem,
model=model,
optimizer=optimizer,
use_lt=False
)
else:
if "pod" in config_["model"]["model_class"]:
model.fit(temperature)
solver = ReducedOrderModelSolver(
problem= problem,
reduction_network=model,
interpolation_network=int_net,
optimizer=optimizer,
)
trainer_args = config.get("trainer", {})
trainer = load_trainer(trainer_args, solver)
train(trainer)
model_test = save_model(solver, trainer, problem, model, int_net)
model_test.eval()
temp_pred = model_test(params).detach().numpy()
temp_true = temperature.detach().numpy()
error = compute_error(temp_true, temp_pred)
print(points.shape)
tria = Triangulation(points[0,:,0], points[0,:,1])
levels_main = torch.linspace(0, temp_true[0].max().item(), steps=100)
abs_error = np.abs(temp_true[0]- temp_pred[0])
levels_diff = torch.linspace(0, abs_error.max().item(), steps=100)
fig, axs = plt.subplots(1,3, figsize=(15,5))
im = axs[0].tricontourf(tria, temp_true[0], levels=levels_main)
axs[0].set_title('True Temperature')
fig.colorbar(im, ax=axs[0])
im = axs[1].tricontourf(tria, temp_pred[0], levels=levels_main)
axs[1].set_title('Predicted Temperature')
fig.colorbar(im, ax=axs[1])
im = axs[2].tricontourf(tria, np.abs(temp_true[0]- temp_pred[0]), levels=levels_diff)
axs[2].set_title('Absolute Error')
fig.colorbar(im, ax=axs[2])
plt.savefig('temperature_comparison.png')
print(f"L2 error on training set: {error:.6f}")
if __name__ == "__main__": if __name__ == "__main__":
main() main()