From 2c5fa4259eb7e7e4d1f3db35155de88f7ef7ed7d Mon Sep 17 00:00:00 2001
From: Filippo Olivo
Date: Fri, 19 Sep 2025 12:06:21 +0200
Subject: [PATCH] add POD-NN model

---
 experiments/pod_nn.yaml    |  22 ++++
 model/pod_nn.py            |  22 ++++
 run.py                     | 264 +++++++++++++++++++++++++++++++++++++
 thermal-conduction-dataset |   2 +-
 4 files changed, 309 insertions(+), 1 deletion(-)
 create mode 100644 experiments/pod_nn.yaml
 create mode 100644 model/pod_nn.py
 create mode 100644 run.py

diff --git a/experiments/pod_nn.yaml b/experiments/pod_nn.yaml
new file mode 100644
index 0000000..774feb2
--- /dev/null
+++ b/experiments/pod_nn.yaml
@@ -0,0 +1,22 @@
+model:
+  model_class: model.pod_nn.PODNN
+  pod_rank: 128
+  layers: [16, 64, 128, 128, 64, 16]
+
+optimizer:
+  optimizer_class: torch.optim.AdamW
+  lr: 0.001
+
+trainer:
+  max_epochs: 10000
+  batch_size: 128
+  train_size: 0.9
+  val_size: 0.1
+  accelerator: 'cuda'
+  devices: 1
+  log_every_n_steps: 0
+  patience: 100
+  log_dir: logs/
+  name: pod_nn
+  # version: null
+  
\ No newline at end of file
diff --git a/model/pod_nn.py b/model/pod_nn.py
new file mode 100644
index 0000000..e4af89e
--- /dev/null
+++ b/model/pod_nn.py
@@ -0,0 +1,22 @@
+import torch
+from pina.model.block import PODBlock
+from pina.model import FeedForward
+
+class PODNN(torch.nn.Module):
+    def __init__(self, pod_rank, layers, func=torch.nn.Softplus):
+        super().__init__()
+        self.pod = PODBlock(pod_rank, scale_coefficients=False)
+        self.nn = FeedForward(
+            input_dimensions=3,
+            output_dimensions=pod_rank,
+            layers=layers,
+            func=func,
+        )
+
+
+    def forward(self, p):
+        coefficients = self.nn(p)
+        return self.pod.expand(coefficients)
+
+    def fit_pod(self, x):
+        self.pod.fit(x)
\ No newline at end of file
diff --git a/run.py b/run.py
new file mode 100644
index 0000000..388788e
--- /dev/null
+++ b/run.py
@@ -0,0 +1,264 @@
+import os
+import yaml
+import importlib
+from pina import Trainer
+import torch
+import numpy as np
+from pina.problem.zoo import SupervisedProblem
+from pina.solver import ReducedOrderModelSolver
+from pina.solver import SupervisedSolver
+from pina.optim import TorchOptimizer
+from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
+from lightning.pytorch import seed_everything
+from lightning.pytorch.loggers import TensorBoardLogger
+from copy import deepcopy
+from datasets import load_dataset
+from torch.utils.data import random_split
+from matplotlib import pyplot as plt
+from matplotlib.tri import Triangulation
+
+
+def compute_error(u_true, u_pred):
+    """Compute the relative L2 error between true and predicted solutions."""
+    return np.linalg.norm(u_true - u_pred) / np.linalg.norm(u_true)
+
+def argparse():
+    """
+    Parse command line arguments for training configuration.
+    """
+    import argparse
+    parser = argparse.ArgumentParser(description="Train a model with specified "
+                                     "parameters.")
+    parser.add_argument('--config', type=str, required=True,
+                        help='Path to the configuration YAML file.')
+    return parser.parse_args()
+
+def load_config(config_file):
+    """
+    Load the training configuration from a YAML file.
+
+    :param str config_file: Path to the configuration file.
+    :return: Configuration dictionary.
+    :rtype: dict
+    """
+    with open(config_file, "r") as f:
+        config = yaml.safe_load(f)
+    return config
+
+def load_model(model_args):
+    """
+    Load the model class and instantiate it with the provided arguments.
+
+    :param dict model_args: Arguments for the model class.
+    :return: An instance of the model class.
+    :rtype: torch.nn.Module
+    """
+    model_class = model_args.pop("model_class", "")
+    module_path, class_name = model_class.rsplit(".", 1)
+    module = importlib.import_module(module_path)
+    cls = getattr(module, class_name)
+    model = cls(**model_args)
+    return model
+
+def load_data():
+    """
+    Load the thermal-conduction dataset and split it into train and test sets.
+
+    :return: Training/testing temperatures, training/testing parameters,
+        and the mesh point coordinates.
+    :rtype: tuple
+        - temp_train (torch.Tensor): Training temperature snapshots.
+        - params_train (torch.Tensor): Training parameters.
+    """
+    snapshots = load_dataset("SISSAmathLab/thermal-conduction", name="snapshots")["default"]
+    geom = load_dataset("SISSAmathLab/thermal-conduction", name="geometry")["default"]
+    points = torch.tensor(np.array(geom["points"]), dtype=torch.float32)
+
+    temperature = torch.tensor(np.array(snapshots["temperature"]), dtype=torch.float32)
+    params = torch.tensor(np.array(snapshots["params"]), dtype=torch.float32)
+
+    train_size = int(0.8 * len(temperature))
+    temp_train, temp_test = temperature[:train_size], temperature[train_size:]
+    params_train, params_test = params[:train_size], params[train_size:]
+
+    return temp_train, params_train, temp_test, params_test, points
+
+
+def load_trainer(trainer_args, solver):
+    """
+    Load and configure the Trainer for training.
+
+    :param dict trainer_args: Arguments for the Trainer.
+    :param ~pina.solver.solver_interface.SolverInterface solver: The solver
+        instance to be used by the Trainer.
+    :return: Configured Trainer instance.
+    :rtype: Trainer
+    """
+    patience = trainer_args.pop("patience", 100)
+    es = EarlyStopping(
+        monitor='val_loss',
+        patience=patience,
+        mode='min',
+        verbose=True
+    )
+
+    checkpoint = ModelCheckpoint(
+        monitor='val_loss',
+        mode='min',
+        save_top_k=1,
+        filename='best_model',
+        save_weights_only=True
+    )
+    logger = TensorBoardLogger(
+        save_dir=trainer_args.pop('log_dir', 'logs'),
+        name=trainer_args.pop('name'),
+        version=f"{trainer_args.pop('version'):03d}" if trainer_args.get('version', None) is not None else None
+    )
+    trainer_args['callbacks'] = [es, checkpoint]
+    trainer_args['solver'] = solver
+    trainer_args['logger'] = logger
+
+    trainer = Trainer(**trainer_args)
+    return trainer
+
+def load_optimizer(optim_args):
+    """
+    Load the optimizer class and instantiate it with the provided arguments.
+
+    :param dict optim_args: Arguments for the optimizer class.
+    :return: An instance of the TorchOptimizer class.
+    :rtype: TorchOptimizer
+    """
+    print("Loading optimizer with args:", optim_args)
+    optim_class = optim_args.pop("optimizer_class", "")
+    module_path, class_name = optim_class.rsplit(".", 1)
+    module = importlib.import_module(module_path)
+    cls = getattr(module, class_name)
+    return TorchOptimizer(
+        cls,
+        **optim_args
+    )
+
+def train(trainer):
+    """
+    Train the model using the provided Trainer instance.
+
+    :param Trainer trainer: The Trainer instance configured for training.
+    """
+    trainer.train()
+
+def save_model(solver, trainer, problem, model, int_net):
+    """
+    Save the trained model and its components to disk.
+
+    :param ~pina.solver.solver_interface.SolverInterface solver: The solver
+        instance containing the trained model.
+    :param Trainer trainer: The Trainer instance used for training.
+    :param ~pina.problem.zoo.SupervisedProblem problem: The problem instance
+        associated with the solver.
+    :param torch.nn.Module model: The trained model to be saved.
+    :param torch.nn.Module int_net: The interpolation network, if used.
+    """
+    model_path = trainer.logger.log_dir.replace("logs", "models")
+    os.makedirs(model_path, exist_ok=True)
+    if int_net is None:
+        solver = SupervisedSolver.load_from_checkpoint(
+            os.path.join(trainer.logger.log_dir, 'checkpoints',
+                         'best_model.ckpt'),
+            problem=problem,
+            model=model,
+            use_lt=False)
+        model = solver.model.cpu()
+        model.eval()
+        torch.save(model.state_dict(), os.path.join(model_path, 'model.pth'))
+        if hasattr(model, 'pod'):
+            torch.save(model.pod.basis, os.path.join(model_path,
+                                                     'pod_basis.pth')
+                       )
+        return model
+    else:
+        solver = ReducedOrderModelSolver.load_from_checkpoint(
+            os.path.join(trainer.logger.log_dir, 'checkpoints',
+                         'best_model.ckpt'),
+            problem=problem,
+            interpolation_network=int_net,
+            reduction_network=model
+        )
+        int_net = solver.model["interpolation_network"].cpu()
+        torch.save(int_net.state_dict(), os.path.join(model_path,
+                                                      'interpolation_network.pth')
+                   )
+        model = solver.model["reduction_network"].cpu()
+        torch.save(model.state_dict(), os.path.join(model_path,
+                                                    'reduction_network.pth')
+                   )
+        if hasattr(model, 'pod'):
+            torch.save(model.pod.basis, os.path.join(model_path,
+                                                     'pod_basis.pth')
+                       )
+
+
+def main():
+    seed_everything(1999, workers=True)
+    args = argparse()
+    config = load_config(args.config)
+    config_ = deepcopy(config)
+    model_args = config.get("model", {})
+    model = load_model(model_args)
+
+    if "interpolation" in config:
+        model_args = config["interpolation"]
+        int_net = load_model(model_args)
+    else:
+        int_net = None
+
+    temperature, params, _, _, points = load_data()
+    problem = SupervisedProblem(output_=temperature, input_=params)
+    optimizer = load_optimizer(config.get("optimizer", {}))
+    if int_net is None:
+        if hasattr(model, 'fit_pod'):
+            model.fit_pod(temperature)
+        solver = SupervisedSolver(
+            problem=problem,
+            model=model,
+            optimizer=optimizer,
+            use_lt=False
+        )
+    else:
+        if "pod" in config_["model"]["model_class"]:
+            model.fit(temperature)
+        solver = ReducedOrderModelSolver(
+            problem=problem,
+            reduction_network=model,
+            interpolation_network=int_net,
+            optimizer=optimizer,
+        )
+    trainer_args = config.get("trainer", {})
+    trainer = load_trainer(trainer_args, solver)
+    train(trainer)
+    model_test = save_model(solver, trainer, problem, model, int_net)
+    model_test.eval()
+    temp_pred = model_test(params).detach().numpy()
+    temp_true = temperature.detach().numpy()
+    error = compute_error(temp_true, temp_pred)
+    print(points.shape)
+    tria = Triangulation(points[0, :, 0], points[0, :, 1])
+
+    levels_main = torch.linspace(0, temp_true[0].max().item(), steps=100)
+    abs_error = np.abs(temp_true[0] - temp_pred[0])
+    levels_diff = torch.linspace(0, abs_error.max().item(), steps=100)
+    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
+    im = axs[0].tricontourf(tria, temp_true[0], levels=levels_main)
+    axs[0].set_title('True Temperature')
+    fig.colorbar(im, ax=axs[0])
+    im = axs[1].tricontourf(tria, temp_pred[0], levels=levels_main)
+    axs[1].set_title('Predicted Temperature')
+    fig.colorbar(im, ax=axs[1])
+    im = axs[2].tricontourf(tria, abs_error, levels=levels_diff)
+    axs[2].set_title('Absolute Error')
+    fig.colorbar(im, ax=axs[2])
+    plt.savefig('temperature_comparison.png')
+    print(f"Relative L2 error on training set: {error:.6f}")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/thermal-conduction-dataset b/thermal-conduction-dataset
index 3d989dd..bb924a5 160000
--- a/thermal-conduction-dataset
+++ b/thermal-conduction-dataset
@@ -1 +1 @@
-Subproject commit 3d989dda38f9d2a70b00ccba6815f7dff26a7dec
+Subproject commit bb924a56a821e4522170edb6ce8886cfc6613470
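
A minimal usage sketch of the PODNN model introduced in model/pod_nn.py, using only classes and methods that appear in this patch (PODNN, fit_pod, the forward pass). The tensor shapes and the synthetic data are assumptions consistent with how run.py feeds the model: snapshots of shape [n_snapshots, n_points] and parameters of shape [n_snapshots, 3].

import torch

from model.pod_nn import PODNN

# Placeholder data standing in for the thermal-conduction snapshots:
# 200 snapshots of a 1024-point field, each driven by 3 parameters
# (shapes are assumptions, not values from the dataset).
snapshots = torch.randn(200, 1024)
params = torch.randn(200, 3)

# Same hyperparameters as experiments/pod_nn.yaml.
model = PODNN(pod_rank=128, layers=[16, 64, 128, 128, 64, 16])

# Fit the POD basis on the snapshot matrix before any forward pass,
# exactly as run.py does via model.fit_pod(temperature).
model.fit_pod(snapshots)

# Forward pass: the FeedForward network maps the 3 parameters to 128
# modal coefficients, and the POD block expands them back to the field.
prediction = model(params)
print(prediction.shape)  # expected: torch.Size([200, 1024])

Training itself is launched with python run.py --config experiments/pod_nn.yaml, which fits the basis on the training snapshots, trains the coefficient network with early stopping, and writes the checkpoint, the POD basis, and temperature_comparison.png as described in run.py.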