From f3be9e99f893f87a8ec39fa6794e52cd30118576 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Thu, 25 Sep 2025 14:44:39 +0200 Subject: [PATCH] implement first GNO --- ThermalSolver/data_module.py | 2 +- ThermalSolver/model/basic_gno.py | 25 +++ ThermalSolver/model/local_gno.py | 74 ++++----- ThermalSolver/module.py | 18 +- run.py | 276 ++----------------------------- 5 files changed, 89 insertions(+), 306 deletions(-) create mode 100644 ThermalSolver/model/basic_gno.py diff --git a/ThermalSolver/data_module.py b/ThermalSolver/data_module.py index 1b74f16..905575a 100644 --- a/ThermalSolver/data_module.py +++ b/ThermalSolver/data_module.py @@ -64,7 +64,7 @@ class GraphDataModule(LightningDataModule): edge_attr = torch.cat( [edge_attr, torch.norm(edge_attr, dim=1).unsqueeze(-1)], dim=1 ) - + return Data( x=boundary_vales.unsqueeze(-1), c=conductivity.unsqueeze(-1), diff --git a/ThermalSolver/model/basic_gno.py b/ThermalSolver/model/basic_gno.py new file mode 100644 index 0000000..bb76f0b --- /dev/null +++ b/ThermalSolver/model/basic_gno.py @@ -0,0 +1,25 @@ +from pina.model import GraphNeuralOperator +import torch +from torch_geometric.data import Data + + +class GNO(torch.nn.Module): + def __init__( + self, x_ch_node, f_ch_node, hidden, layers, edge_ch=0, out_ch=1 + ): + super().__init__() + + lifting_operator = torch.nn.Linear(x_ch_node + f_ch_node, hidden) + self.gno = GraphNeuralOperator( + lifting_operator=lifting_operator, + projection_operator=torch.nn.Linear(hidden, out_ch), + edge_features=edge_ch, + n_layers=layers, + internal_n_layers=2, + shared_weights=False, + ) + + def forward(self, x, c, edge_index, edge_attr): + x = torch.cat([x, c], dim=-1) + x = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) + return self.gno(x) diff --git a/ThermalSolver/model/local_gno.py b/ThermalSolver/model/local_gno.py index e9c575d..7bd58f1 100644 --- a/ThermalSolver/model/local_gno.py +++ b/ThermalSolver/model/local_gno.py @@ -8,19 +8,16 @@ class FiLM(nn.Module): def __init__(self, c_ch, h_ch): super().__init__() self.net = nn.Sequential( - nn.Linear(c_ch, 2*h_ch), - nn.SiLU(), - nn.Linear(2*h_ch, 2*h_ch) + nn.Linear(c_ch, 2 * h_ch), nn.SiLU(), nn.Linear(2 * h_ch, 2 * h_ch) ) # init to identity: gamma≈0 (so 1+gamma=1), beta=0 nn.init.zeros_(self.net[-1].weight) nn.init.zeros_(self.net[-1].bias) - self.norm = nn.LayerNorm(h_ch) def forward(self, h, c): gb = self.net(c) gamma, beta = gb.chunk(2, dim=-1) - return (1 + gamma) * self.norm(h) + beta + return (1 + gamma) * h + beta class ConditionalGNOBlock(MessagePassing): @@ -28,46 +25,35 @@ class ConditionalGNOBlock(MessagePassing): Message passing with FiLM applied to the MESSAGE m_ij, using edge context c_ij = (c_i + c_j)/2. """ - def __init__(self, hidden_ch, edge_ch=0, aggr="mean"): + + def __init__(self, hidden_ch, edge_ch=0, aggr="add"): super().__init__(aggr=aggr, node_dim=0) - self.pre_norm = nn.LayerNorm(hidden_ch) - - # raw message builder - self.msg = nn.Sequential( - nn.Linear(2*hidden_ch + edge_ch, 2*hidden_ch), - nn.SiLU(), - nn.Linear(2*hidden_ch, hidden_ch) - ) - # FiLM over the message (per-edge) self.film_msg = FiLM(c_ch=hidden_ch, h_ch=hidden_ch) - - # node update with residual - self.update_mlp = nn.Sequential( - nn.Linear(2*hidden_ch, hidden_ch), + self.edge_attr_net = nn.Sequential( + nn.Linear(edge_ch, hidden_ch // 2), nn.SiLU(), - nn.Linear(hidden_ch, hidden_ch) + nn.Linear(hidden_ch // 2, hidden_ch), + ) + self.x_net = nn.Sequential( + nn.Linear(hidden_ch, hidden_ch * 2), + nn.SiLU(), + nn.Linear(hidden_ch * 2, hidden_ch), ) def forward(self, x, c, edge_index, edge_attr=None): - # pre-norm helps stability - x_in = x - x = self.pre_norm(x) - m = self.propagate(edge_index, x=x, c=c, edge_attr=edge_attr) - out = self.update_mlp(torch.cat([x_in, m], dim=-1)) - return x_in + out # residual + return self.propagate(edge_index, x=x, c=c, edge_attr=edge_attr) - def message(self, x_i, x_j, c_i, c_j, edge_attr): + def update(self, aggr_out, x): + return self.x_net(x) + aggr_out + + def message(self, x_j, c_i, c_j, edge_attr): + # c_ij = (c_i + c_j)/2 + c_ij = 0.5 * (c_i + c_j) + m = self.film_msg(x_j, c_ij) if edge_attr is not None: - m_in = torch.cat([x_i, x_j, edge_attr], dim=-1) - else: - m_in = torch.cat([x_i, x_j], dim=-1) - - m_raw = self.msg(m_in) - - # edge conditioning: simple mean - c_ctx = 0.5 * (c_i + c_j) - m = self.film_msg(m_raw, c_ctx) + a_ij = self.edge_attr_net(edge_attr) + m = m * a_ij return m @@ -79,7 +65,10 @@ class GatingGNO(nn.Module): Out: y : [N, out_ch] """ - def __init__(self, x_ch_node, f_ch_node, hidden, layers, edge_ch=0, out_ch=1): + + def __init__( + self, x_ch_node, f_ch_node, hidden, layers, edge_ch=0, out_ch=1 + ): super().__init__() self.encoder_x = nn.Sequential( nn.Linear(x_ch_node, hidden // 2), @@ -92,12 +81,15 @@ class GatingGNO(nn.Module): nn.Linear(hidden // 2, hidden), ) self.blocks = nn.ModuleList( - [ConditionalGNOBlock(hidden_ch=hidden, edge_ch=edge_ch) for _ in range(layers)] + [ + ConditionalGNOBlock(hidden_ch=hidden, edge_ch=edge_ch) + for _ in range(layers) + ] ) self.dec = nn.Sequential( - nn.LayerNorm(hidden), + nn.Linear(hidden, hidden // 2), nn.SiLU(), - nn.Linear(hidden, out_ch) + nn.Linear(hidden // 2, out_ch), ) def forward(self, x, c, edge_index, edge_attr=None): @@ -105,4 +97,4 @@ class GatingGNO(nn.Module): c = self.encoder_c(c) # [N,H] for blk in self.blocks: x = blk(x, c, edge_index, edge_attr=edge_attr) - return self.dec(x) \ No newline at end of file + return self.dec(x) diff --git a/ThermalSolver/module.py b/ThermalSolver/module.py index 922d810..83de199 100644 --- a/ThermalSolver/module.py +++ b/ThermalSolver/module.py @@ -4,7 +4,12 @@ from torch_geometric.data import Batch class GraphSolver(LightningModule): - def __init__(self, model: torch.nn.Module, loss: torch.nn.Module = None, unrolling_steps: int = 10): + def __init__( + self, + model: torch.nn.Module, + loss: torch.nn.Module = None, + unrolling_steps: int = 10, + ): super().__init__() self.model = model self.loss = loss if loss is not None else torch.nn.MSELoss() @@ -18,7 +23,7 @@ class GraphSolver(LightningModule): edge_attr: torch.Tensor, ): return self.model(x, c, edge_index, edge_attr) - + def _compute_loss_train(self, x, x_prev, y): return self.loss(x, y) + self.loss(x, x_prev) @@ -27,7 +32,7 @@ class GraphSolver(LightningModule): def _preprocess_batch(self, batch: Batch): return batch.x, batch.y, batch.c, batch.edge_index, batch.edge_attr - + def _log_loss(self, loss, batch, stage: str): self.log( f"{stage}_loss", @@ -41,13 +46,14 @@ class GraphSolver(LightningModule): def training_step(self, batch: Batch, _): x, y, c, edge_index, edge_attr = self._preprocess_batch(batch) + loss = 0.0 for _ in range(self.unrolling_steps): x_prev = x.detach() x = self(x_prev, c, edge_index=edge_index, edge_attr=edge_attr) - loss = self.loss(x, y) + loss += self.loss(x, y) self._log_loss(loss, batch, "train") return loss - + def validation_step(self, batch: Batch, _): x, y, c, edge_index, edge_attr = self._preprocess_batch(batch) for _ in range(self.unrolling_steps): @@ -70,5 +76,5 @@ class GraphSolver(LightningModule): return loss def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=5e-3) + optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) return optimizer diff --git a/run.py b/run.py index 388788e..e9bc24c 100644 --- a/run.py +++ b/run.py @@ -1,264 +1,24 @@ -import os -import yaml -import importlib -from pina import Trainer -import torch -import numpy as np -from pina.problem.zoo import SupervisedProblem -from pina.solver import ReducedOrderModelSolver -from pina.solver import SupervisedSolver -from pina.optim import TorchOptimizer -from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint -from lightning.pytorch import seed_everything -from lightning.pytorch.loggers import TensorBoardLogger -from copy import deepcopy -from datasets import load_dataset -from torch.utils.data import random_split -from matplotlib import pyplot as plt -from matplotlib.tri import Triangulation - - -def compute_error(u_true, u_pred): - """Compute the L2 error between true and predicted solutions.""" - return np.linalg.norm(u_true - u_pred) / np.linalg.norm(u_true) - -def argparse(): - """ - Parse command line arguments for training configuration. - """ - import argparse - parser = argparse.ArgumentParser(description="Train a model with specified " - "parameters.") - parser.add_argument('--config', type=str, required=True, - help='Path to the configuration YAML file.') - return parser.parse_args() - -def load_config(config_file): - """ - Configure the training parameters. - - :param str config_file: Path to the configuration file. - :return: Configuration dictionary. - :rtype: dict - """ - with open(config_file, "r") as f: - config = yaml.safe_load(f) - return config - -def load_model(model_args): - """ - Load the model class and instantiate it with the provided arguments. - - :param dict model_args: Arguments for the model class. - :return: An instance of the model class. - :rtype: torch.nn.Module - """ - model_class = model_args.pop("model_class", "") - module_path, class_name = model_class.rsplit(".", 1) - module = importlib.import_module(module_path) - cls = getattr(module, class_name) - model = cls(**model_args) - return model - -def load_data(): - """ - Load the dataset from the specified path and preprocess it. - - :param dict data_args: Arguments for loading the dataset. - :return: Training and testing datasets, normalizers, and points. - :rtype: tuple - - u_train (torch.Tensor): Training simulations. - - p_train (torch.Tensor): Training parameters. - """ - snapshots = load_dataset("SISSAmathLab/thermal-conduction", name="snapshots")["default"] - geom = load_dataset("SISSAmathLab/thermal-conduction", name="geometry")["default"] - points = torch.tensor(np.array(geom["points"]), dtype=torch.float32) - - temperature = torch.tensor(np.array(snapshots["temperature"]), dtype=torch.float32) - params = torch.tensor(np.array(snapshots["params"]), dtype=torch.float32) - - train_size = int(0.8 * len(temperature)) - temp_train, temp_test = temperature[:train_size], temperature[train_size:] - params_train, params_test = params[:train_size], params[train_size:] - - return temp_train, params_train, temp_test, params_test, points - - -def load_trainer(trainer_args, solver): - """ - Load and configure the Trainer for training. - - :param dict trainer_args: Arguments for the Trainer. - :param ~pina.solver.solver_interface.SolverInterface solver: The solver - instance to be used by the Trainer. - :return: Configured Trainer instance. - :rtype: Trainer - """ - patience = trainer_args.pop("patience", 100) - es = EarlyStopping( - monitor='val_loss', - patience=patience, - mode='min', - verbose=True - ) - - checkpoint = ModelCheckpoint( - monitor='val_loss', - mode='min', - save_top_k=1, - filename='best_model', - save_weights_only=True - ) - logger = TensorBoardLogger( - save_dir=trainer_args.pop('log_dir', 'logs'), - name=trainer_args.pop('name'), - version=f"{trainer_args.pop('version'):03d}" if trainer_args.get('version', None) is not None else None - ) - trainer_args['callbacks'] = [es, checkpoint] - trainer_args['solver'] = solver - trainer_args['logger'] = logger - - trainer = Trainer(**trainer_args) - return trainer - -def load_optimizer(optim_args): - """ - Load the optimizer class and instantiate it with the provided arguments. - - :param dict optim_args: Arguments for the optimizer class. - :return: An instance of the TorchOptimizer class. - :rtype: TorchOptimizer - """ - print("Loading optimizer with args:", optim_args) - optim_class = optim_args.pop("optimizer_class", "") - module_path, class_name = optim_class.rsplit(".", 1) - module = importlib.import_module(module_path) - cls = getattr(module, class_name) - return TorchOptimizer( - cls, - **optim_args - ) - -def train(trainer): - """ - Train the model using the provided Trainer instance. - - :param Trainer trainer: The Trainer instance configured for training. - """ - trainer.train() - -def save_model(solver, trainer, problem, model, int_net): - """ - Save the trained model and its components to disk. - - :param ~pina.solver.solver_interface.SolverInterface solver: The solver - instance containing the trained model. - :param Trainer trainer: The Trainer instance used for training. - :param ~pina.problem.zoo.SupervisedProblem problem: The problem instance - associated with the solver. - :param torch.nn.Module model: The trained model to be saved. - :param torch.nn.Module int_net: The interpolation network, if used. - """ - model_path = trainer.logger.log_dir.replace("logs", "models") - os.makedirs(model_path, exist_ok=True) - if int_net is None: - solver = SupervisedSolver.load_from_checkpoint( - os.path.join(trainer.logger.log_dir, 'checkpoints', - 'best_model.ckpt'), - problem=problem, - model=model, - use_lt=False) - model = solver.model.cpu() - model.eval() - torch.save(model.state_dict(), os.path.join(model_path, 'model.pth')) - if hasattr(model, 'pod'): - torch.save(model.pod.basis, os.path.join(model_path, - 'pod_basis.pth') - ) - return model - else: - solver = ReducedOrderModelSolver.load_from_checkpoint( - os.path.join(trainer.logger.log_dir, 'checkpoints', - 'best_model.ckpt'), - problem=problem, - interpolation_network=int_net, - reduction_network=model - ) - int_net = solver.model["interpolation_network"].cpu() - torch.save(int_net.state_dict(), os.path.join(model_path, - 'interpolation_network.pth') - ) - model = solver.model["reduction_network"].cpu() - torch.save(model.state_dict(), os.path.join(model_path, - 'reduction_network.pth') - ) - if hasattr(model, 'pod'): - torch.save(model.pod.basis, os.path.join(model_path, - 'pod_basis.pth') - ) +from lightning import Trainer +from ThermalSolver.module import GraphSolver +from ThermalSolver.data_module import GraphDataModule +from ThermalSolver.model.local_gno import GatingGNO +from ThermalSolver.model.basic_gno import GNO def main(): - seed_everything(1999, workers=True) - args = argparse() - config = load_config(args.config) - config_ = deepcopy(config) - model_args = config.get("model", {}) - model = load_model(model_args) - - if "interpolation" in config: - model_args = config["interpolation"] - int_net = load_model(model_args) - else: - int_net = None - - temperature, params, _, _, points = load_data() - problem = SupervisedProblem(output_=temperature, input_=params) - optimizer = load_optimizer(config.get("optimizer", {})) - if int_net is None: - if hasattr(model, 'fit_pod'): - model.fit_pod(temperature) - solver = SupervisedSolver( - problem=problem, - model=model, - optimizer=optimizer, - use_lt=False - ) - else: - if "pod" in config_["model"]["model_class"]: - model.fit(temperature) - solver = ReducedOrderModelSolver( - problem= problem, - reduction_network=model, - interpolation_network=int_net, - optimizer=optimizer, - ) - trainer_args = config.get("trainer", {}) - trainer = load_trainer(trainer_args, solver) - train(trainer) - model_test = save_model(solver, trainer, problem, model, int_net) - model_test.eval() - temp_pred = model_test(params).detach().numpy() - temp_true = temperature.detach().numpy() - error = compute_error(temp_true, temp_pred) - print(points.shape) - tria = Triangulation(points[0,:,0], points[0,:,1]) - - levels_main = torch.linspace(0, temp_true[0].max().item(), steps=100) - abs_error = np.abs(temp_true[0]- temp_pred[0]) - levels_diff = torch.linspace(0, abs_error.max().item(), steps=100) - fig, axs = plt.subplots(1,3, figsize=(15,5)) - im = axs[0].tricontourf(tria, temp_true[0], levels=levels_main) - axs[0].set_title('True Temperature') - fig.colorbar(im, ax=axs[0]) - im = axs[1].tricontourf(tria, temp_pred[0], levels=levels_main) - axs[1].set_title('Predicted Temperature') - fig.colorbar(im, ax=axs[1]) - im = axs[2].tricontourf(tria, np.abs(temp_true[0]- temp_pred[0]), levels=levels_diff) - axs[2].set_title('Absolute Error') - fig.colorbar(im, ax=axs[2]) - plt.savefig('temperature_comparison.png') - print(f"L2 error on training set: {error:.6f}") + trainer = Trainer(max_epochs=100, accelerator="cuda", devices=1) + data_module = GraphDataModule( + hf_repo="SISSAmathLab/thermal-conduction", + split_name="easy", + train_size=0.8, + val_size=0.1, + test_size=0.1, + batch_size=8, + ) + model = GatingGNO(x_ch_node=1, f_ch_node=1, hidden=16, layers=8, edge_ch=3, out_ch=1) + solver = GraphSolver(model) + trainer.fit(solver, datamodule=data_module) + print("Done!") if __name__ == "__main__": main() \ No newline at end of file