add POD-NN model
This commit is contained in:
22
experiments/pod_nn.yaml
Normal file
22
experiments/pod_nn.yaml
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
model:
|
||||||
|
model_class: model.pod_nn.PODNN
|
||||||
|
pod_rank: 128
|
||||||
|
layers: [16, 64, 128, 128, 64, 16]
|
||||||
|
|
||||||
|
optimizer:
|
||||||
|
optimizer_class: torch.optim.AdamW
|
||||||
|
lr: 0.001
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
max_epochs: 10000
|
||||||
|
batch_size: 128
|
||||||
|
train_size: 0.9
|
||||||
|
val_size: 0.1
|
||||||
|
accelerator: 'cuda'
|
||||||
|
devices: 1
|
||||||
|
log_every_n_steps: 0
|
||||||
|
patience: 100
|
||||||
|
log_dir: logs/
|
||||||
|
name: pod_nn
|
||||||
|
# version: null
|
||||||
|
|
||||||
22
model/pod_nn.py
Normal file
22
model/pod_nn.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
import torch
|
||||||
|
from pina.model.block import PODBlock
|
||||||
|
from pina.model import FeedForward
|
||||||
|
|
||||||
|
class PODNN(torch.nn.Module):
|
||||||
|
def __init__(self, pod_rank, layers, func=torch.nn.Softplus):
|
||||||
|
super().__init__()
|
||||||
|
self.pod = PODBlock(pod_rank, scale_coefficients=False)
|
||||||
|
self.nn = FeedForward(
|
||||||
|
input_dimensions=3,
|
||||||
|
output_dimensions=pod_rank,
|
||||||
|
layers=layers,
|
||||||
|
func=func,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, p):
|
||||||
|
coefficients = self.nn(p)
|
||||||
|
return self.pod.expand(coefficients)
|
||||||
|
|
||||||
|
def fit_pod(self, x):
|
||||||
|
self.pod.fit(x)
|
||||||
264
run.py
Normal file
264
run.py
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
import os
|
||||||
|
import yaml
|
||||||
|
import importlib
|
||||||
|
from pina import Trainer
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from pina.problem.zoo import SupervisedProblem
|
||||||
|
from pina.solver import ReducedOrderModelSolver
|
||||||
|
from pina.solver import SupervisedSolver
|
||||||
|
from pina.optim import TorchOptimizer
|
||||||
|
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
|
||||||
|
from lightning.pytorch import seed_everything
|
||||||
|
from lightning.pytorch.loggers import TensorBoardLogger
|
||||||
|
from copy import deepcopy
|
||||||
|
from datasets import load_dataset
|
||||||
|
from torch.utils.data import random_split
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
from matplotlib.tri import Triangulation
|
||||||
|
|
||||||
|
|
||||||
|
def compute_error(u_true, u_pred):
|
||||||
|
"""Compute the L2 error between true and predicted solutions."""
|
||||||
|
return np.linalg.norm(u_true - u_pred) / np.linalg.norm(u_true)
|
||||||
|
|
||||||
|
def argparse():
|
||||||
|
"""
|
||||||
|
Parse command line arguments for training configuration.
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
parser = argparse.ArgumentParser(description="Train a model with specified "
|
||||||
|
"parameters.")
|
||||||
|
parser.add_argument('--config', type=str, required=True,
|
||||||
|
help='Path to the configuration YAML file.')
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
def load_config(config_file):
|
||||||
|
"""
|
||||||
|
Configure the training parameters.
|
||||||
|
|
||||||
|
:param str config_file: Path to the configuration file.
|
||||||
|
:return: Configuration dictionary.
|
||||||
|
:rtype: dict
|
||||||
|
"""
|
||||||
|
with open(config_file, "r") as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
return config
|
||||||
|
|
||||||
|
def load_model(model_args):
|
||||||
|
"""
|
||||||
|
Load the model class and instantiate it with the provided arguments.
|
||||||
|
|
||||||
|
:param dict model_args: Arguments for the model class.
|
||||||
|
:return: An instance of the model class.
|
||||||
|
:rtype: torch.nn.Module
|
||||||
|
"""
|
||||||
|
model_class = model_args.pop("model_class", "")
|
||||||
|
module_path, class_name = model_class.rsplit(".", 1)
|
||||||
|
module = importlib.import_module(module_path)
|
||||||
|
cls = getattr(module, class_name)
|
||||||
|
model = cls(**model_args)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def load_data():
|
||||||
|
"""
|
||||||
|
Load the dataset from the specified path and preprocess it.
|
||||||
|
|
||||||
|
:param dict data_args: Arguments for loading the dataset.
|
||||||
|
:return: Training and testing datasets, normalizers, and points.
|
||||||
|
:rtype: tuple
|
||||||
|
- u_train (torch.Tensor): Training simulations.
|
||||||
|
- p_train (torch.Tensor): Training parameters.
|
||||||
|
"""
|
||||||
|
snapshots = load_dataset("SISSAmathLab/thermal-conduction", name="snapshots")["default"]
|
||||||
|
geom = load_dataset("SISSAmathLab/thermal-conduction", name="geometry")["default"]
|
||||||
|
points = torch.tensor(np.array(geom["points"]), dtype=torch.float32)
|
||||||
|
|
||||||
|
temperature = torch.tensor(np.array(snapshots["temperature"]), dtype=torch.float32)
|
||||||
|
params = torch.tensor(np.array(snapshots["params"]), dtype=torch.float32)
|
||||||
|
|
||||||
|
train_size = int(0.8 * len(temperature))
|
||||||
|
temp_train, temp_test = temperature[:train_size], temperature[train_size:]
|
||||||
|
params_train, params_test = params[:train_size], params[train_size:]
|
||||||
|
|
||||||
|
return temp_train, params_train, temp_test, params_test, points
|
||||||
|
|
||||||
|
|
||||||
|
def load_trainer(trainer_args, solver):
|
||||||
|
"""
|
||||||
|
Load and configure the Trainer for training.
|
||||||
|
|
||||||
|
:param dict trainer_args: Arguments for the Trainer.
|
||||||
|
:param ~pina.solver.solver_interface.SolverInterface solver: The solver
|
||||||
|
instance to be used by the Trainer.
|
||||||
|
:return: Configured Trainer instance.
|
||||||
|
:rtype: Trainer
|
||||||
|
"""
|
||||||
|
patience = trainer_args.pop("patience", 100)
|
||||||
|
es = EarlyStopping(
|
||||||
|
monitor='val_loss',
|
||||||
|
patience=patience,
|
||||||
|
mode='min',
|
||||||
|
verbose=True
|
||||||
|
)
|
||||||
|
|
||||||
|
checkpoint = ModelCheckpoint(
|
||||||
|
monitor='val_loss',
|
||||||
|
mode='min',
|
||||||
|
save_top_k=1,
|
||||||
|
filename='best_model',
|
||||||
|
save_weights_only=True
|
||||||
|
)
|
||||||
|
logger = TensorBoardLogger(
|
||||||
|
save_dir=trainer_args.pop('log_dir', 'logs'),
|
||||||
|
name=trainer_args.pop('name'),
|
||||||
|
version=f"{trainer_args.pop('version'):03d}" if trainer_args.get('version', None) is not None else None
|
||||||
|
)
|
||||||
|
trainer_args['callbacks'] = [es, checkpoint]
|
||||||
|
trainer_args['solver'] = solver
|
||||||
|
trainer_args['logger'] = logger
|
||||||
|
|
||||||
|
trainer = Trainer(**trainer_args)
|
||||||
|
return trainer
|
||||||
|
|
||||||
|
def load_optimizer(optim_args):
|
||||||
|
"""
|
||||||
|
Load the optimizer class and instantiate it with the provided arguments.
|
||||||
|
|
||||||
|
:param dict optim_args: Arguments for the optimizer class.
|
||||||
|
:return: An instance of the TorchOptimizer class.
|
||||||
|
:rtype: TorchOptimizer
|
||||||
|
"""
|
||||||
|
print("Loading optimizer with args:", optim_args)
|
||||||
|
optim_class = optim_args.pop("optimizer_class", "")
|
||||||
|
module_path, class_name = optim_class.rsplit(".", 1)
|
||||||
|
module = importlib.import_module(module_path)
|
||||||
|
cls = getattr(module, class_name)
|
||||||
|
return TorchOptimizer(
|
||||||
|
cls,
|
||||||
|
**optim_args
|
||||||
|
)
|
||||||
|
|
||||||
|
def train(trainer):
|
||||||
|
"""
|
||||||
|
Train the model using the provided Trainer instance.
|
||||||
|
|
||||||
|
:param Trainer trainer: The Trainer instance configured for training.
|
||||||
|
"""
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
def save_model(solver, trainer, problem, model, int_net):
|
||||||
|
"""
|
||||||
|
Save the trained model and its components to disk.
|
||||||
|
|
||||||
|
:param ~pina.solver.solver_interface.SolverInterface solver: The solver
|
||||||
|
instance containing the trained model.
|
||||||
|
:param Trainer trainer: The Trainer instance used for training.
|
||||||
|
:param ~pina.problem.zoo.SupervisedProblem problem: The problem instance
|
||||||
|
associated with the solver.
|
||||||
|
:param torch.nn.Module model: The trained model to be saved.
|
||||||
|
:param torch.nn.Module int_net: The interpolation network, if used.
|
||||||
|
"""
|
||||||
|
model_path = trainer.logger.log_dir.replace("logs", "models")
|
||||||
|
os.makedirs(model_path, exist_ok=True)
|
||||||
|
if int_net is None:
|
||||||
|
solver = SupervisedSolver.load_from_checkpoint(
|
||||||
|
os.path.join(trainer.logger.log_dir, 'checkpoints',
|
||||||
|
'best_model.ckpt'),
|
||||||
|
problem=problem,
|
||||||
|
model=model,
|
||||||
|
use_lt=False)
|
||||||
|
model = solver.model.cpu()
|
||||||
|
model.eval()
|
||||||
|
torch.save(model.state_dict(), os.path.join(model_path, 'model.pth'))
|
||||||
|
if hasattr(model, 'pod'):
|
||||||
|
torch.save(model.pod.basis, os.path.join(model_path,
|
||||||
|
'pod_basis.pth')
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
else:
|
||||||
|
solver = ReducedOrderModelSolver.load_from_checkpoint(
|
||||||
|
os.path.join(trainer.logger.log_dir, 'checkpoints',
|
||||||
|
'best_model.ckpt'),
|
||||||
|
problem=problem,
|
||||||
|
interpolation_network=int_net,
|
||||||
|
reduction_network=model
|
||||||
|
)
|
||||||
|
int_net = solver.model["interpolation_network"].cpu()
|
||||||
|
torch.save(int_net.state_dict(), os.path.join(model_path,
|
||||||
|
'interpolation_network.pth')
|
||||||
|
)
|
||||||
|
model = solver.model["reduction_network"].cpu()
|
||||||
|
torch.save(model.state_dict(), os.path.join(model_path,
|
||||||
|
'reduction_network.pth')
|
||||||
|
)
|
||||||
|
if hasattr(model, 'pod'):
|
||||||
|
torch.save(model.pod.basis, os.path.join(model_path,
|
||||||
|
'pod_basis.pth')
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
seed_everything(1999, workers=True)
|
||||||
|
args = argparse()
|
||||||
|
config = load_config(args.config)
|
||||||
|
config_ = deepcopy(config)
|
||||||
|
model_args = config.get("model", {})
|
||||||
|
model = load_model(model_args)
|
||||||
|
|
||||||
|
if "interpolation" in config:
|
||||||
|
model_args = config["interpolation"]
|
||||||
|
int_net = load_model(model_args)
|
||||||
|
else:
|
||||||
|
int_net = None
|
||||||
|
|
||||||
|
temperature, params, _, _, points = load_data()
|
||||||
|
problem = SupervisedProblem(output_=temperature, input_=params)
|
||||||
|
optimizer = load_optimizer(config.get("optimizer", {}))
|
||||||
|
if int_net is None:
|
||||||
|
if hasattr(model, 'fit_pod'):
|
||||||
|
model.fit_pod(temperature)
|
||||||
|
solver = SupervisedSolver(
|
||||||
|
problem=problem,
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
use_lt=False
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if "pod" in config_["model"]["model_class"]:
|
||||||
|
model.fit(temperature)
|
||||||
|
solver = ReducedOrderModelSolver(
|
||||||
|
problem= problem,
|
||||||
|
reduction_network=model,
|
||||||
|
interpolation_network=int_net,
|
||||||
|
optimizer=optimizer,
|
||||||
|
)
|
||||||
|
trainer_args = config.get("trainer", {})
|
||||||
|
trainer = load_trainer(trainer_args, solver)
|
||||||
|
train(trainer)
|
||||||
|
model_test = save_model(solver, trainer, problem, model, int_net)
|
||||||
|
model_test.eval()
|
||||||
|
temp_pred = model_test(params).detach().numpy()
|
||||||
|
temp_true = temperature.detach().numpy()
|
||||||
|
error = compute_error(temp_true, temp_pred)
|
||||||
|
print(points.shape)
|
||||||
|
tria = Triangulation(points[0,:,0], points[0,:,1])
|
||||||
|
|
||||||
|
levels_main = torch.linspace(0, temp_true[0].max().item(), steps=100)
|
||||||
|
abs_error = np.abs(temp_true[0]- temp_pred[0])
|
||||||
|
levels_diff = torch.linspace(0, abs_error.max().item(), steps=100)
|
||||||
|
fig, axs = plt.subplots(1,3, figsize=(15,5))
|
||||||
|
im = axs[0].tricontourf(tria, temp_true[0], levels=levels_main)
|
||||||
|
axs[0].set_title('True Temperature')
|
||||||
|
fig.colorbar(im, ax=axs[0])
|
||||||
|
im = axs[1].tricontourf(tria, temp_pred[0], levels=levels_main)
|
||||||
|
axs[1].set_title('Predicted Temperature')
|
||||||
|
fig.colorbar(im, ax=axs[1])
|
||||||
|
im = axs[2].tricontourf(tria, np.abs(temp_true[0]- temp_pred[0]), levels=levels_diff)
|
||||||
|
axs[2].set_title('Absolute Error')
|
||||||
|
fig.colorbar(im, ax=axs[2])
|
||||||
|
plt.savefig('temperature_comparison.png')
|
||||||
|
print(f"L2 error on training set: {error:.6f}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Submodule thermal-conduction-dataset updated: 3d989dda38...bb924a56a8
Reference in New Issue
Block a user