add POD-NN model

This commit is contained in:
Filippo Olivo
2025-09-19 12:06:21 +02:00
parent 4c2f30e676
commit 2c5fa4259e
4 changed files with 309 additions and 1 deletions

22
experiments/pod_nn.yaml Normal file
View File

@@ -0,0 +1,22 @@
# Experiment configuration for the POD-NN model, consumed by run.py.
# Model section: `model_class` is a dotted import path resolved at runtime;
# the remaining keys are forwarded as keyword arguments to the model class.
model:
model_class: model.pod_nn.PODNN
# Passed to the POD block and also used as the network output size.
pod_rank: 128
# Forwarded to the feed-forward network as its `layers` argument.
layers: [16, 64, 128, 128, 64, 16]
# Optimizer section: `optimizer_class` is a dotted import path; the remaining
# keys are forwarded as keyword arguments.
optimizer:
optimizer_class: torch.optim.AdamW
lr: 0.001
# Trainer section: `patience`, `log_dir`, `name` and `version` are consumed by
# run.py (early stopping / TensorBoard logger); everything else is forwarded to
# the Trainer constructor.
trainer:
max_epochs: 10000
batch_size: 128
train_size: 0.9
val_size: 0.1
accelerator: 'cuda'
devices: 1
log_every_n_steps: 0
# Early-stopping patience on `val_loss`.
patience: 100
log_dir: logs/
name: pod_nn
# Optional integer run version; run.py formats it as a zero-padded 3-digit
# string for the logger. Left unset here.
# version: null

22
model/pod_nn.py Normal file
View File

@@ -0,0 +1,22 @@
import torch
from pina.model.block import PODBlock
from pina.model import FeedForward
class PODNN(torch.nn.Module):
    """
    POD-NN model: a feed-forward network predicts POD coefficients from the
    input parameters, and a POD block expands them back to the full field.

    :param int pod_rank: Rank passed to the POD block; also the number of
        outputs of the feed-forward network.
    :param list layers: Forwarded to ``FeedForward`` as its ``layers``
        argument.
    :param func: Activation forwarded to ``FeedForward`` (default
        ``torch.nn.Softplus``).
    :param int input_dimensions: Number of input parameters fed to the
        network. Defaults to 3, the value previously hard-coded.
    """

    def __init__(self, pod_rank, layers, func=torch.nn.Softplus,
                 input_dimensions=3):
        super().__init__()
        # The POD basis is not learned by gradient descent: it must be fitted
        # on snapshots through ``fit_pod`` before ``forward`` is meaningful.
        self.pod = PODBlock(pod_rank, scale_coefficients=False)
        self.nn = FeedForward(
            input_dimensions=input_dimensions,
            output_dimensions=pod_rank,
            layers=layers,
            func=func,
        )

    def forward(self, p):
        """
        Map parameters to the reconstructed field.

        :param torch.Tensor p: Input parameters.
        :return: The POD expansion of the predicted coefficients.
        """
        coefficients = self.nn(p)
        return self.pod.expand(coefficients)

    def fit_pod(self, x):
        """
        Fit the POD block on the given snapshots.

        :param torch.Tensor x: Snapshots passed to ``PODBlock.fit``.
        """
        self.pod.fit(x)

264
run.py Normal file
View File

@@ -0,0 +1,264 @@
import os
import yaml
import importlib
from pina import Trainer
import torch
import numpy as np
from pina.problem.zoo import SupervisedProblem
from pina.solver import ReducedOrderModelSolver
from pina.solver import SupervisedSolver
from pina.optim import TorchOptimizer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import TensorBoardLogger
from copy import deepcopy
from datasets import load_dataset
from torch.utils.data import random_split
from matplotlib import pyplot as plt
from matplotlib.tri import Triangulation
def compute_error(u_true, u_pred):
    """
    Compute the relative L2 error between true and predicted solutions.

    :param numpy.ndarray u_true: Reference solution.
    :param numpy.ndarray u_pred: Predicted solution.
    :return: ``||u_true - u_pred|| / ||u_true||`` using ``numpy.linalg.norm``
        with its default norm.
    """
    residual_norm = np.linalg.norm(u_true - u_pred)
    reference_norm = np.linalg.norm(u_true)
    return residual_norm / reference_norm
def argparse():
    """
    Parse the command line arguments of the training script.

    The only option is the mandatory ``--config`` path.

    :return: The parsed arguments, exposing the ``config`` attribute.
    :rtype: argparse.Namespace
    """
    # Imported locally: this function shares its name with the stdlib module.
    from argparse import ArgumentParser

    cli = ArgumentParser(
        description="Train a model with specified parameters."
    )
    cli.add_argument(
        '--config',
        type=str,
        required=True,
        help='Path to the configuration YAML file.',
    )
    return cli.parse_args()
def load_config(config_file):
    """
    Read the training configuration from a YAML file.

    :param str config_file: Path to the configuration file.
    :return: The content of the file as parsed by ``yaml.safe_load``.
    :rtype: dict
    """
    with open(config_file, "r") as stream:
        return yaml.safe_load(stream)
def load_model(model_args):
    """
    Load the model class and instantiate it with the provided arguments.

    :param dict model_args: Arguments for the model class. The
        ``model_class`` entry is the dotted import path of the class
        (``package.module.ClassName``); every other entry is forwarded to the
        class constructor as a keyword argument. The dictionary is not
        modified.
    :raises ValueError: If ``model_class`` is missing or is not a dotted
        ``module.ClassName`` path.
    :return: An instance of the model class.
    :rtype: torch.nn.Module
    """
    # Work on a shallow copy so the caller's configuration stays intact.
    kwargs = dict(model_args)
    model_class = kwargs.pop("model_class", "")
    module_path, sep, class_name = model_class.rpartition(".")
    if not sep or not module_path or not class_name:
        raise ValueError(
            "'model_class' must be a dotted path of the form "
            f"'package.module.ClassName', got {model_class!r}."
        )
    module = importlib.import_module(module_path)
    cls = getattr(module, class_name)
    return cls(**kwargs)
def load_data(train_fraction=0.8):
    """
    Load the thermal-conduction dataset and split it into train and test sets.

    The snapshots are split sequentially (no shuffling): the first
    ``train_fraction`` of the samples form the training set, the rest the
    test set.

    :param float train_fraction: Fraction of the snapshots used for training,
        in ``[0, 1]``. Defaults to 0.8.
    :raises ValueError: If ``train_fraction`` is outside ``[0, 1]``.
    :return: Training and testing data, and the mesh points.
    :rtype: tuple
        - temp_train (torch.Tensor): Training temperature snapshots.
        - params_train (torch.Tensor): Training parameters.
        - temp_test (torch.Tensor): Testing temperature snapshots.
        - params_test (torch.Tensor): Testing parameters.
        - points (torch.Tensor): Geometry points of the dataset.
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(
            f"train_fraction must be in [0, 1], got {train_fraction!r}."
        )
    repo = "SISSAmathLab/thermal-conduction"
    snapshots = load_dataset(repo, name="snapshots")["default"]
    geom = load_dataset(repo, name="geometry")["default"]

    def _to_tensor(column):
        # Go through numpy first to build the float32 tensor in one shot.
        return torch.tensor(np.array(column), dtype=torch.float32)

    points = _to_tensor(geom["points"])
    temperature = _to_tensor(snapshots["temperature"])
    params = _to_tensor(snapshots["params"])

    train_size = int(train_fraction * len(temperature))
    temp_train, temp_test = temperature[:train_size], temperature[train_size:]
    params_train, params_test = params[:train_size], params[train_size:]
    return temp_train, params_train, temp_test, params_test, points
def load_trainer(trainer_args, solver):
    """
    Load and configure the Trainer for training.

    Early stopping and best-model checkpointing both monitor ``val_loss``;
    logging goes through a TensorBoard logger.

    :param dict trainer_args: Arguments for the Trainer. The entries
        ``patience`` (default 100), ``log_dir`` (default ``'logs'``), ``name``
        (required) and ``version`` (optional integer) are consumed here; all
        the others are forwarded to the Trainer. The dictionary is not
        modified.
    :param ~pina.solver.solver_interface.SolverInterface solver: The solver
        instance to be used by the Trainer.
    :raises KeyError: If ``name`` is missing from ``trainer_args``.
    :return: Configured Trainer instance.
    :rtype: Trainer
    """
    # Work on a shallow copy so the caller's configuration stays intact.
    trainer_kwargs = dict(trainer_args)
    patience = trainer_kwargs.pop("patience", 100)
    log_dir = trainer_kwargs.pop('log_dir', 'logs')
    name = trainer_kwargs.pop('name')
    # Always remove 'version': leaving a null entry behind would forward an
    # unexpected 'version' keyword to the Trainer constructor.
    version = trainer_kwargs.pop('version', None)

    es = EarlyStopping(
        monitor='val_loss',
        patience=patience,
        mode='min',
        verbose=True
    )
    checkpoint = ModelCheckpoint(
        monitor='val_loss',
        mode='min',
        save_top_k=1,
        filename='best_model',
        save_weights_only=True
    )
    logger = TensorBoardLogger(
        save_dir=log_dir,
        name=name,
        version=None if version is None else f"{version:03d}"
    )
    trainer_kwargs['callbacks'] = [es, checkpoint]
    trainer_kwargs['solver'] = solver
    trainer_kwargs['logger'] = logger
    return Trainer(**trainer_kwargs)
def load_optimizer(optim_args):
    """
    Load the optimizer class and wrap it with the provided arguments.

    :param dict optim_args: Arguments for the optimizer class. The
        ``optimizer_class`` entry is the dotted import path of the class
        (e.g. ``torch.optim.AdamW``); every other entry is forwarded to
        ``TorchOptimizer`` as a keyword argument. The dictionary is not
        modified.
    :raises ValueError: If ``optimizer_class`` is missing or is not a dotted
        ``module.ClassName`` path.
    :return: An instance of the TorchOptimizer class.
    :rtype: TorchOptimizer
    """
    print("Loading optimizer with args:", optim_args)
    # Work on a shallow copy so the caller's configuration stays intact.
    kwargs = dict(optim_args)
    optim_class = kwargs.pop("optimizer_class", "")
    module_path, sep, class_name = optim_class.rpartition(".")
    if not sep or not module_path or not class_name:
        raise ValueError(
            "'optimizer_class' must be a dotted path of the form "
            f"'package.module.ClassName', got {optim_class!r}."
        )
    module = importlib.import_module(module_path)
    cls = getattr(module, class_name)
    return TorchOptimizer(cls, **kwargs)
def train(trainer):
    """
    Run the training loop of the given Trainer.

    :param Trainer trainer: The Trainer instance configured for training.
    :return: Nothing; the result of ``trainer.train()`` is discarded.
    """
    trainer.train()
    return None
def _save_pod_basis(model, model_path):
    """Save ``model.pod.basis`` to ``pod_basis.pth`` if the model has a POD."""
    if hasattr(model, 'pod'):
        torch.save(model.pod.basis, os.path.join(model_path, 'pod_basis.pth'))


def save_model(solver, trainer, problem, model, int_net):
    """
    Reload the best checkpoint and save the trained networks to disk.

    The output directory is the logger directory with ``"logs"`` replaced by
    ``"models"``. Weights are written as state dicts; when the model exposes
    a ``pod`` attribute its basis is saved as well.

    :param ~pina.solver.solver_interface.SolverInterface solver: The solver
        instance containing the trained model. It is replaced by the solver
        reloaded from the best checkpoint.
    :param Trainer trainer: The Trainer instance used for training.
    :param ~pina.problem.zoo.SupervisedProblem problem: The problem instance
        associated with the solver.
    :param torch.nn.Module model: The trained model to be saved (the
        reduction network when ``int_net`` is given).
    :param torch.nn.Module int_net: The interpolation network, if used.
    :return: A callable module on CPU, in eval mode, usable for inference:
        the reloaded model in the supervised case, the reloaded solver in the
        reduced-order case.
    :rtype: torch.nn.Module
    """
    log_dir = trainer.logger.log_dir
    # NOTE: str.replace rewrites every occurrence of "logs" in the path.
    model_path = log_dir.replace("logs", "models")
    os.makedirs(model_path, exist_ok=True)
    checkpoint_path = os.path.join(log_dir, 'checkpoints', 'best_model.ckpt')

    if int_net is None:
        solver = SupervisedSolver.load_from_checkpoint(
            checkpoint_path,
            problem=problem,
            model=model,
            use_lt=False)
        model = solver.model.cpu()
        model.eval()
        torch.save(model.state_dict(), os.path.join(model_path, 'model.pth'))
        _save_pod_basis(model, model_path)
        return model

    solver = ReducedOrderModelSolver.load_from_checkpoint(
        checkpoint_path,
        problem=problem,
        interpolation_network=int_net,
        reduction_network=model
    )
    int_net = solver.model["interpolation_network"].cpu()
    torch.save(int_net.state_dict(),
               os.path.join(model_path, 'interpolation_network.pth'))
    model = solver.model["reduction_network"].cpu()
    torch.save(model.state_dict(),
               os.path.join(model_path, 'reduction_network.pth'))
    _save_pod_basis(model, model_path)
    # Previously this branch returned None although the caller evaluates the
    # returned object. The solver is returned because it is the object that
    # chains the two networks -- assumes its forward maps parameters to the
    # reconstructed field; TODO confirm against the solver implementation.
    solver = solver.cpu()
    solver.eval()
    return solver
def _plot_temperature_comparison(points, temp_true, temp_pred,
                                 path='temperature_comparison.png'):
    """Plot true, predicted and absolute-error fields of the first sample."""
    tria = Triangulation(points[0, :, 0], points[0, :, 1])
    abs_error = np.abs(temp_true[0] - temp_pred[0])
    # True and predicted panels share the same levels to be comparable.
    levels_main = torch.linspace(0, temp_true[0].max().item(), steps=100)
    levels_diff = torch.linspace(0, abs_error.max().item(), steps=100)
    panels = (
        ('True Temperature', temp_true[0], levels_main),
        ('Predicted Temperature', temp_pred[0], levels_main),
        ('Absolute Error', abs_error, levels_diff),
    )
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (title, field, levels) in zip(axs, panels):
        im = ax.tricontourf(tria, field, levels=levels)
        ax.set_title(title)
        fig.colorbar(im, ax=ax)
    plt.savefig(path)
    plt.close(fig)


def main():
    """
    Train the model described by the YAML configuration and evaluate it.

    Steps: seed, parse ``--config``, build the model (and the optional
    interpolation network), fit the POD when available, train, save the best
    model, then compute the relative L2 error on the training set and save a
    comparison plot for the first training sample.
    """
    seed_everything(1999, workers=True)
    args = argparse()
    config = load_config(args.config)
    # Pristine copy: the loaders may consume entries of ``config``.
    config_ = deepcopy(config)
    model = load_model(config.get("model", {}))
    if "interpolation" in config:
        int_net = load_model(config["interpolation"])
    else:
        int_net = None
    temperature, params, _, _, points = load_data()
    problem = SupervisedProblem(output_=temperature, input_=params)
    optimizer = load_optimizer(config.get("optimizer", {}))
    if int_net is None:
        if hasattr(model, 'fit_pod'):
            model.fit_pod(temperature)
        solver = SupervisedSolver(
            problem=problem,
            model=model,
            optimizer=optimizer,
            use_lt=False
        )
    else:
        # NOTE(review): here the reduction network is expected to expose
        # ``fit`` directly (not ``fit_pod``) -- confirm with the model used.
        if "pod" in config_["model"]["model_class"]:
            model.fit(temperature)
        solver = ReducedOrderModelSolver(
            problem=problem,
            reduction_network=model,
            interpolation_network=int_net,
            optimizer=optimizer,
        )
    trainer = load_trainer(config.get("trainer", {}), solver)
    train(trainer)
    model_test = save_model(solver, trainer, problem, model, int_net)
    if model_test is None:
        # Nothing callable was returned: the weights are saved, but there is
        # no model to evaluate, so stop here instead of crashing.
        print("No model returned by save_model: skipping evaluation.")
        return
    model_test.eval()
    # Inference only: no need to build the autograd graph.
    with torch.no_grad():
        temp_pred = model_test(params).detach().numpy()
    temp_true = temperature.detach().numpy()
    error = compute_error(temp_true, temp_pred)
    print(points.shape)
    _plot_temperature_comparison(points, temp_true, temp_pred)
    print(f"L2 error on training set: {error:.6f}")


if __name__ == "__main__":
    main()