"""Train a supervised or reduced-order model on the thermal-conduction dataset.

The script reads a YAML configuration, builds the model(s), optimizer, solver
and trainer it describes, trains, stores the best weights on disk and finally
plots the first training snapshot against its prediction.
"""

import importlib
import os
from copy import deepcopy

import yaml
from pina import Trainer
import torch
import numpy as np
from pina.problem.zoo import SupervisedProblem
from pina.solver import ReducedOrderModelSolver
from pina.solver import SupervisedSolver
from pina.optim import TorchOptimizer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import TensorBoardLogger
from datasets import load_dataset
from torch.utils.data import random_split
from matplotlib import pyplot as plt
from matplotlib.tri import Triangulation


def compute_error(u_true, u_pred):
    """
    Compute the relative L2 error between true and predicted solutions.

    :param numpy.ndarray u_true: Reference values.
    :param numpy.ndarray u_pred: Predicted values, same shape as ``u_true``.
    :return: ``||u_true - u_pred|| / ||u_true||`` using ``numpy.linalg.norm``
        over the whole array.
    :rtype: float
    """
    return np.linalg.norm(u_true - u_pred) / np.linalg.norm(u_true)


def argparse():
    """
    Parse command line arguments for training configuration.

    :return: Parsed arguments; ``config`` holds the path to the YAML file.
    :rtype: argparse.Namespace
    """
    # Imported locally under an alias because this function's own name
    # shadows the standard-library module at module scope.
    import argparse as _argparse

    parser = _argparse.ArgumentParser(
        description="Train a model with specified parameters."
    )
    parser.add_argument(
        '--config',
        type=str,
        required=True,
        help='Path to the configuration YAML file.',
    )
    return parser.parse_args()


def load_config(config_file):
    """
    Read the training configuration from a YAML file.

    :param str config_file: Path to the configuration file.
    :return: Configuration dictionary (empty if the file has no content).
    :rtype: dict
    """
    with open(config_file, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    # ``safe_load`` returns None for an empty document; callers use ``.get``.
    return config or {}


def _import_class(dotted_path, key):
    """
    Resolve ``package.module.ClassName`` to the class object.

    :param str dotted_path: Fully qualified name of the class.
    :param str key: Configuration key the path came from (for the error text).
    :return: The attribute ``ClassName`` of the imported module.
    :raises ValueError: If ``dotted_path`` is not a dotted string.
    """
    if not isinstance(dotted_path, str) or "." not in dotted_path:
        raise ValueError(
            f"'{key}' must be a fully qualified dotted path such as "
            f"'package.module.ClassName', got {dotted_path!r}."
        )
    module_path, class_name = dotted_path.rsplit(".", 1)
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


def load_model(model_args):
    """
    Load the model class and instantiate it with the provided arguments.

    :param dict model_args: Arguments for the model class. The entry
        ``model_class`` is the dotted path of the class; every other entry
        is forwarded to its constructor. The dictionary is not modified.
    :return: An instance of the model class.
    :rtype: torch.nn.Module
    :raises ValueError: If ``model_class`` is missing or not a dotted path.
    """
    # Work on a copy so the caller's configuration keeps ``model_class``.
    model_kwargs = dict(model_args)
    cls = _import_class(model_kwargs.pop("model_class", ""), "model_class")
    return cls(**model_kwargs)


def load_data():
    """
    Load the thermal-conduction dataset and split it into train and test.

    The first 80% of the snapshots (in dataset order) form the training set,
    the remainder the test set.

    :return: Training and testing data, and the mesh points.
    :rtype: tuple

        - temp_train (torch.Tensor): Training temperature snapshots.
        - params_train (torch.Tensor): Training parameters.
        - temp_test (torch.Tensor): Testing temperature snapshots.
        - params_test (torch.Tensor): Testing parameters.
        - points (torch.Tensor): Geometry points as stored in the dataset.
    """
    snapshots = load_dataset(
        "SISSAmathLab/thermal-conduction", name="snapshots"
    )["default"]
    geom = load_dataset(
        "SISSAmathLab/thermal-conduction", name="geometry"
    )["default"]

    points = torch.tensor(np.array(geom["points"]), dtype=torch.float32)
    temperature = torch.tensor(
        np.array(snapshots["temperature"]), dtype=torch.float32
    )
    params = torch.tensor(np.array(snapshots["params"]), dtype=torch.float32)

    train_size = int(0.8 * len(temperature))
    temp_train, temp_test = temperature[:train_size], temperature[train_size:]
    params_train, params_test = params[:train_size], params[train_size:]
    return temp_train, params_train, temp_test, params_test, points


def load_trainer(trainer_args, solver):
    """
    Load and configure the Trainer for training.

    The keys ``patience``, ``log_dir``, ``name`` and ``version`` are consumed
    here (early stopping and logger setup); all remaining entries are
    forwarded to :class:`~pina.Trainer`. The dictionary is not modified.

    :param dict trainer_args: Arguments for the Trainer.
    :param ~pina.solver.solver_interface.SolverInterface solver: The solver
        instance to be used by the Trainer.
    :return: Configured Trainer instance.
    :rtype: Trainer
    :raises KeyError: If ``name`` is missing from ``trainer_args``.
    """
    trainer_kwargs = dict(trainer_args)

    patience = trainer_kwargs.pop("patience", 100)
    es = EarlyStopping(
        monitor='val_loss',
        patience=patience,
        mode='min',
        verbose=True,
    )
    checkpoint = ModelCheckpoint(
        monitor='val_loss',
        mode='min',
        save_top_k=1,
        filename='best_model',
        save_weights_only=True,
    )

    # Always remove ``version`` so an explicit null in the configuration is
    # not forwarded to Trainer as an unknown keyword argument.
    version = trainer_kwargs.pop('version', None)
    logger = TensorBoardLogger(
        save_dir=trainer_kwargs.pop('log_dir', 'logs'),
        name=trainer_kwargs.pop('name'),
        version=f"{version:03d}" if version is not None else None,
    )

    trainer_kwargs['callbacks'] = [es, checkpoint]
    trainer_kwargs['solver'] = solver
    trainer_kwargs['logger'] = logger
    return Trainer(**trainer_kwargs)


def load_optimizer(optim_args):
    """
    Load the optimizer class and wrap it with the provided arguments.

    :param dict optim_args: Arguments for the optimizer class. The entry
        ``optimizer_class`` is the dotted path of the class; every other
        entry is forwarded to :class:`TorchOptimizer`. The dictionary is not
        modified.
    :return: An instance of the TorchOptimizer class.
    :rtype: TorchOptimizer
    :raises ValueError: If ``optimizer_class`` is missing or not a dotted
        path.
    """
    print("Loading optimizer with args:", optim_args)
    optim_kwargs = dict(optim_args)
    cls = _import_class(
        optim_kwargs.pop("optimizer_class", ""), "optimizer_class"
    )
    return TorchOptimizer(cls, **optim_kwargs)


def train(trainer):
    """
    Train the model using the provided Trainer instance.

    :param Trainer trainer: The Trainer instance configured for training.
    """
    trainer.train()


class _ReducedOrderPredictor(torch.nn.Module):
    """
    Compose an interpolation network with the decoder of a reduction network.

    NOTE(review): intended to mirror the forward pass of
    ``ReducedOrderModelSolver`` (decode of the interpolated latent state);
    confirm against the installed PINA version.
    """

    def __init__(self, reduction_network, interpolation_network):
        super().__init__()
        self.reduction_network = reduction_network
        self.interpolation_network = interpolation_network

    def forward(self, x):
        """Map parameters ``x`` to the reconstructed full-order field."""
        return self.reduction_network.decode(self.interpolation_network(x))


def _save_pod_basis(model, model_path):
    """Store ``model.pod.basis`` as ``pod_basis.pth`` if the model has a POD."""
    if hasattr(model, 'pod'):
        torch.save(model.pod.basis, os.path.join(model_path, 'pod_basis.pth'))


def save_model(solver, trainer, problem, model, int_net):
    """
    Reload the best checkpoint and save the trained networks to disk.

    Weights are written to the logger directory with ``logs`` replaced by
    ``models`` in its path.

    :param ~pina.solver.solver_interface.SolverInterface solver: The solver
        instance containing the trained model.
    :param Trainer trainer: The Trainer instance used for training.
    :param ~pina.problem.zoo.SupervisedProblem problem: The problem instance
        associated with the solver.
    :param torch.nn.Module model: The trained model to be saved.
    :param torch.nn.Module int_net: The interpolation network, if used.
    :return: A CPU module in evaluation mode that maps parameters to
        predictions: the model itself when ``int_net`` is ``None``, otherwise
        the composition of the interpolation network and the reduction
        network decoder.
    :rtype: torch.nn.Module
    """
    log_dir = trainer.logger.log_dir
    model_path = log_dir.replace("logs", "models")
    os.makedirs(model_path, exist_ok=True)
    checkpoint_path = os.path.join(log_dir, 'checkpoints', 'best_model.ckpt')

    if int_net is None:
        solver = SupervisedSolver.load_from_checkpoint(
            checkpoint_path,
            problem=problem,
            model=model,
            use_lt=False,
        )
        model = solver.model.cpu()
        model.eval()
        torch.save(model.state_dict(), os.path.join(model_path, 'model.pth'))
        _save_pod_basis(model, model_path)
        return model

    solver = ReducedOrderModelSolver.load_from_checkpoint(
        checkpoint_path,
        problem=problem,
        interpolation_network=int_net,
        reduction_network=model,
    )
    int_net = solver.model["interpolation_network"].cpu()
    torch.save(
        int_net.state_dict(),
        os.path.join(model_path, 'interpolation_network.pth'),
    )
    model = solver.model["reduction_network"].cpu()
    torch.save(
        model.state_dict(),
        os.path.join(model_path, 'reduction_network.pth'),
    )
    _save_pod_basis(model, model_path)

    # Previously this branch returned nothing, so the caller could not
    # evaluate a reduced-order model after training.
    predictor = _ReducedOrderPredictor(model, int_net)
    predictor.eval()
    return predictor


def main():
    """
    Run the full pipeline: configure, train, save and plot one comparison.

    Writes ``temperature_comparison.png`` to the working directory and prints
    the relative L2 error on the training set.
    """
    seed_everything(1999, workers=True)
    args = argparse()
    config = load_config(args.config)

    model = load_model(config.get("model", {}))
    if "interpolation" in config:
        int_net = load_model(config["interpolation"])
    else:
        int_net = None

    temperature, params, _, _, points = load_data()
    problem = SupervisedProblem(output_=temperature, input_=params)
    optimizer = load_optimizer(config.get("optimizer", {}))

    if int_net is None:
        if hasattr(model, 'fit_pod'):
            model.fit_pod(temperature)
        solver = SupervisedSolver(
            problem=problem,
            model=model,
            optimizer=optimizer,
            use_lt=False,
        )
    else:
        if "pod" in config["model"]["model_class"]:
            model.fit(temperature)
        solver = ReducedOrderModelSolver(
            problem=problem,
            reduction_network=model,
            interpolation_network=int_net,
            optimizer=optimizer,
        )

    trainer = load_trainer(config.get("trainer", {}), solver)
    train(trainer)

    model_test = save_model(solver, trainer, problem, model, int_net)
    model_test.eval()
    with torch.no_grad():
        temp_pred = model_test(params).detach().numpy()
    temp_true = temperature.detach().numpy()
    error = compute_error(temp_true, temp_pred)

    print(points.shape)
    # Only the first geometry entry and the first snapshot are plotted.
    tria = Triangulation(points[0, :, 0].numpy(), points[0, :, 1].numpy())
    abs_error = np.abs(temp_true[0] - temp_pred[0])
    levels_main = np.linspace(0, temp_true[0].max().item(), 100)
    levels_diff = np.linspace(0, abs_error.max().item(), 100)

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        (temp_true[0], levels_main, 'True Temperature'),
        (temp_pred[0], levels_main, 'Predicted Temperature'),
        (abs_error, levels_diff, 'Absolute Error'),
    )
    for ax, (values, levels, title) in zip(axs, panels):
        im = ax.tricontourf(tria, values, levels=levels)
        ax.set_title(title)
        fig.colorbar(im, ax=ax)
    fig.savefig('temperature_comparison.png')
    plt.close(fig)

    print(f"L2 error on training set: {error:.6f}")


if __name__ == "__main__":
    main()