implement first GNO
ThermalSolver/data_module.py
@@ -64,7 +64,7 @@ class GraphDataModule(LightningDataModule):
         edge_attr = torch.cat(
             [edge_attr, torch.norm(edge_attr, dim=1).unsqueeze(-1)], dim=1
         )
-
+
         return Data(
             x=boundary_vales.unsqueeze(-1),
             c=conductivity.unsqueeze(-1),
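For intuition: the hunk above appends each edge vector's Euclidean norm as an extra feature channel, so if edge_attr starts as the 2-D relative position of the connected nodes (an assumption; the dataset code is not shown here), every edge ends up with three features, matching the edge_ch=3 used in run.py below. A minimal sketch:

import torch

edge_attr = torch.randn(200, 2)  # assumed: per-edge relative positions x_j - x_i
edge_attr = torch.cat(
    [edge_attr, torch.norm(edge_attr, dim=1).unsqueeze(-1)], dim=1
)
print(edge_attr.shape)  # torch.Size([200, 3]), i.e. edge_ch=3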
25
ThermalSolver/model/basic_gno.py
Normal file
@@ -0,0 +1,25 @@
from pina.model import GraphNeuralOperator
import torch
from torch_geometric.data import Data


class GNO(torch.nn.Module):
    def __init__(
        self, x_ch_node, f_ch_node, hidden, layers, edge_ch=0, out_ch=1
    ):
        super().__init__()

        lifting_operator = torch.nn.Linear(x_ch_node + f_ch_node, hidden)
        self.gno = GraphNeuralOperator(
            lifting_operator=lifting_operator,
            projection_operator=torch.nn.Linear(hidden, out_ch),
            edge_features=edge_ch,
            n_layers=layers,
            internal_n_layers=2,
            shared_weights=False,
        )

    def forward(self, x, c, edge_index, edge_attr):
        x = torch.cat([x, c], dim=-1)
        x = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
        return self.gno(x)
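A minimal smoke test for the new wrapper might look as follows; the channel sizes mirror the run.py configuration below, the graph is random, and it assumes pina's GraphNeuralOperator accepts the torch_geometric Data object built in forward:

import torch
from ThermalSolver.model.basic_gno import GNO

model = GNO(x_ch_node=1, f_ch_node=1, hidden=16, layers=4, edge_ch=3)
n_nodes, n_edges = 100, 400
x = torch.randn(n_nodes, 1)             # boundary values, one channel per node
c = torch.randn(n_nodes, 1)             # conductivity, one channel per node
edge_index = torch.randint(0, n_nodes, (2, n_edges))
edge_attr = torch.randn(n_edges, 3)     # e.g. relative position plus its norm
y = model(x, c, edge_index, edge_attr)  # expected shape: [n_nodes, 1]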
ThermalSolver/model/local_gno.py
@@ -8,19 +8,16 @@ class FiLM(nn.Module):
     def __init__(self, c_ch, h_ch):
         super().__init__()
         self.net = nn.Sequential(
-            nn.Linear(c_ch, 2*h_ch),
-            nn.SiLU(),
-            nn.Linear(2*h_ch, 2*h_ch)
+            nn.Linear(c_ch, 2 * h_ch), nn.SiLU(), nn.Linear(2 * h_ch, 2 * h_ch)
         )
         # init to identity: gamma≈0 (so 1+gamma=1), beta=0
         nn.init.zeros_(self.net[-1].weight)
         nn.init.zeros_(self.net[-1].bias)
-        self.norm = nn.LayerNorm(h_ch)

     def forward(self, h, c):
         gb = self.net(c)
         gamma, beta = gb.chunk(2, dim=-1)
-        return (1 + gamma) * self.norm(h) + beta
+        return (1 + gamma) * h + beta
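With the LayerNorm removed, the zero-initialized last layer makes FiLM exactly the identity at the start of training: gamma = beta = 0, so (1 + gamma) * h + beta == h. A quick check (a sketch, assuming the post-commit FiLM is importable):

import torch
from ThermalSolver.model.local_gno import FiLM

film = FiLM(c_ch=16, h_ch=16)
h, c = torch.randn(8, 16), torch.randn(8, 16)
assert torch.allclose(film(h, c), h)  # identity at initialization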


 class ConditionalGNOBlock(MessagePassing):
@@ -28,46 +25,35 @@ class ConditionalGNOBlock(MessagePassing):
     Message passing with FiLM applied to the MESSAGE m_ij,
     using edge context c_ij = (c_i + c_j)/2.
     """
-    def __init__(self, hidden_ch, edge_ch=0, aggr="mean"):
+
+    def __init__(self, hidden_ch, edge_ch=0, aggr="add"):
         super().__init__(aggr=aggr, node_dim=0)
-        self.pre_norm = nn.LayerNorm(hidden_ch)
-
-        # raw message builder
-        self.msg = nn.Sequential(
-            nn.Linear(2*hidden_ch + edge_ch, 2*hidden_ch),
-            nn.SiLU(),
-            nn.Linear(2*hidden_ch, hidden_ch)
-        )
-
-        # FiLM over the message (per-edge)
         self.film_msg = FiLM(c_ch=hidden_ch, h_ch=hidden_ch)
-
-        # node update with residual
-        self.update_mlp = nn.Sequential(
-            nn.Linear(2*hidden_ch, hidden_ch),
+        self.edge_attr_net = nn.Sequential(
+            nn.Linear(edge_ch, hidden_ch // 2),
             nn.SiLU(),
-            nn.Linear(hidden_ch, hidden_ch)
+            nn.Linear(hidden_ch // 2, hidden_ch),
         )
+        self.x_net = nn.Sequential(
+            nn.Linear(hidden_ch, hidden_ch * 2),
+            nn.SiLU(),
+            nn.Linear(hidden_ch * 2, hidden_ch),
+        )

     def forward(self, x, c, edge_index, edge_attr=None):
-        # pre-norm helps stability
-        x_in = x
-        x = self.pre_norm(x)
-        m = self.propagate(edge_index, x=x, c=c, edge_attr=edge_attr)
-        out = self.update_mlp(torch.cat([x_in, m], dim=-1))
-        return x_in + out  # residual
+        return self.propagate(edge_index, x=x, c=c, edge_attr=edge_attr)

-    def message(self, x_i, x_j, c_i, c_j, edge_attr):
-        if edge_attr is not None:
-            m_in = torch.cat([x_i, x_j, edge_attr], dim=-1)
-        else:
-            m_in = torch.cat([x_i, x_j], dim=-1)
-
-        m_raw = self.msg(m_in)
-
-        # edge conditioning: simple mean
-        c_ctx = 0.5 * (c_i + c_j)
-        m = self.film_msg(m_raw, c_ctx)
+    def update(self, aggr_out, x):
+        return self.x_net(x) + aggr_out
+
+    def message(self, x_j, c_i, c_j, edge_attr):
+        # c_ij = (c_i + c_j)/2
+        c_ij = 0.5 * (c_i + c_j)
+        m = self.film_msg(x_j, c_ij)
+        a_ij = self.edge_attr_net(edge_attr)
+        m = m * a_ij
         return m
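Unrolled into plain torch, one pass of the rewritten block computes the following (a sketch; it assumes PyG's default source-to-target flow and that edge_attr is always provided, which the new message() requires):

import torch

def conditional_block_semantics(x, c, edge_index, edge_attr, film, edge_attr_net, x_net):
    src, dst = edge_index                       # messages flow src -> dst
    c_ij = 0.5 * (c[src] + c[dst])              # per-edge conductivity context
    m = film(x[src], c_ij)                      # FiLM-modulated neighbor features
    m = m * edge_attr_net(edge_attr)            # multiplicative per-edge gate
    aggr = torch.zeros_like(x).index_add_(0, dst, m)  # aggr="add"
    return x_net(x) + aggr                      # update(): node transform + sum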
@@ -79,7 +65,10 @@ class GatingGNO(nn.Module):
     Out:
         y : [N, out_ch]
     """
-    def __init__(self, x_ch_node, f_ch_node, hidden, layers, edge_ch=0, out_ch=1):
+
+    def __init__(
+        self, x_ch_node, f_ch_node, hidden, layers, edge_ch=0, out_ch=1
+    ):
         super().__init__()
         self.encoder_x = nn.Sequential(
             nn.Linear(x_ch_node, hidden // 2),
@@ -92,12 +81,15 @@ class GatingGNO(nn.Module):
             nn.Linear(hidden // 2, hidden),
         )
         self.blocks = nn.ModuleList(
-            [ConditionalGNOBlock(hidden_ch=hidden, edge_ch=edge_ch) for _ in range(layers)]
+            [
+                ConditionalGNOBlock(hidden_ch=hidden, edge_ch=edge_ch)
+                for _ in range(layers)
+            ]
         )
         self.dec = nn.Sequential(
             nn.LayerNorm(hidden),
             nn.Linear(hidden, hidden // 2),
             nn.SiLU(),
-            nn.Linear(hidden, out_ch)
+            nn.Linear(hidden // 2, out_ch),
         )

     def forward(self, x, c, edge_index, edge_attr=None):
@@ -105,4 +97,4 @@ class GatingGNO(nn.Module):
         c = self.encoder_c(c)  # [N,H]
         for blk in self.blocks:
             x = blk(x, c, edge_index, edge_attr=edge_attr)
-        return self.dec(x)
+        return self.dec(x)
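End to end, the post-commit GatingGNO maps [N, 1] node states and [N, 1] conductivities to [N, out_ch] predictions; a shape check using the run.py configuration (a sketch on a random graph):

import torch
from ThermalSolver.model.local_gno import GatingGNO

model = GatingGNO(x_ch_node=1, f_ch_node=1, hidden=16, layers=8, edge_ch=3, out_ch=1)
x, c = torch.randn(50, 1), torch.randn(50, 1)
edge_index = torch.randint(0, 50, (2, 200))
edge_attr = torch.randn(200, 3)
y = model(x, c, edge_index, edge_attr=edge_attr)
print(y.shape)  # torch.Size([50, 1])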
ThermalSolver/module.py
@@ -4,7 +4,12 @@ from torch_geometric.data import Batch


 class GraphSolver(LightningModule):
-    def __init__(self, model: torch.nn.Module, loss: torch.nn.Module = None, unrolling_steps: int = 10):
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        loss: torch.nn.Module = None,
+        unrolling_steps: int = 10,
+    ):
         super().__init__()
         self.model = model
         self.loss = loss if loss is not None else torch.nn.MSELoss()
@@ -18,7 +23,7 @@ class GraphSolver(LightningModule):
         edge_attr: torch.Tensor,
     ):
         return self.model(x, c, edge_index, edge_attr)
-
+
     def _compute_loss_train(self, x, x_prev, y):
         return self.loss(x, y) + self.loss(x, x_prev)
@@ -27,7 +32,7 @@ class GraphSolver(LightningModule):

     def _preprocess_batch(self, batch: Batch):
         return batch.x, batch.y, batch.c, batch.edge_index, batch.edge_attr
-
+
     def _log_loss(self, loss, batch, stage: str):
         self.log(
             f"{stage}_loss",
@@ -41,13 +46,14 @@ class GraphSolver(LightningModule):

     def training_step(self, batch: Batch, _):
         x, y, c, edge_index, edge_attr = self._preprocess_batch(batch)
+        loss = 0.0
         for _ in range(self.unrolling_steps):
             x_prev = x.detach()
             x = self(x_prev, c, edge_index=edge_index, edge_attr=edge_attr)
-            loss = self.loss(x, y)
+            loss += self.loss(x, y)
         self._log_loss(loss, batch, "train")
         return loss

     def validation_step(self, batch: Batch, _):
         x, y, c, edge_index, edge_attr = self._preprocess_batch(batch)
         for _ in range(self.unrolling_steps):
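The reworked training_step sums the supervision loss over all unrolled solver applications instead of overwriting it with the last step's value, while detaching the state before each application so gradients are truncated to a single step. In isolation (a sketch, with model and loss_fn standing in for the solver's attributes):

import torch

def unrolled_loss(model, loss_fn, x0, y, c, edge_index, edge_attr, steps):
    x, total = x0, 0.0
    for _ in range(steps):
        x_prev = x.detach()  # truncate backprop through the rollout
        x = model(x_prev, c, edge_index=edge_index, edge_attr=edge_attr)
        total = total + loss_fn(x, y)  # supervise every intermediate state
    return total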
@@ -70,5 +76,5 @@ class GraphSolver(LightningModule):
         return loss

     def configure_optimizers(self):
-        optimizer = torch.optim.Adam(self.parameters(), lr=5e-3)
+        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
         return optimizer
276
run.py
@@ -1,264 +1,24 @@
-import os
-import yaml
-import importlib
-from pina import Trainer
-import torch
-import numpy as np
-from pina.problem.zoo import SupervisedProblem
-from pina.solver import ReducedOrderModelSolver
-from pina.solver import SupervisedSolver
-from pina.optim import TorchOptimizer
-from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
-from lightning.pytorch import seed_everything
-from lightning.pytorch.loggers import TensorBoardLogger
-from copy import deepcopy
-from datasets import load_dataset
-from torch.utils.data import random_split
-from matplotlib import pyplot as plt
-from matplotlib.tri import Triangulation
-
-
-def compute_error(u_true, u_pred):
-    """Compute the L2 error between true and predicted solutions."""
-    return np.linalg.norm(u_true - u_pred) / np.linalg.norm(u_true)
-
-
-def argparse():
-    """
-    Parse command line arguments for training configuration.
-    """
-    import argparse
-    parser = argparse.ArgumentParser(description="Train a model with specified "
-                                                 "parameters.")
-    parser.add_argument('--config', type=str, required=True,
-                        help='Path to the configuration YAML file.')
-    return parser.parse_args()
-
-
-def load_config(config_file):
-    """
-    Configure the training parameters.
-
-    :param str config_file: Path to the configuration file.
-    :return: Configuration dictionary.
-    :rtype: dict
-    """
-    with open(config_file, "r") as f:
-        config = yaml.safe_load(f)
-    return config
-
-
-def load_model(model_args):
-    """
-    Load the model class and instantiate it with the provided arguments.
-
-    :param dict model_args: Arguments for the model class.
-    :return: An instance of the model class.
-    :rtype: torch.nn.Module
-    """
-    model_class = model_args.pop("model_class", "")
-    module_path, class_name = model_class.rsplit(".", 1)
-    module = importlib.import_module(module_path)
-    cls = getattr(module, class_name)
-    model = cls(**model_args)
-    return model
-
-
-def load_data():
-    """
-    Load the dataset from the specified path and preprocess it.
-
-    :param dict data_args: Arguments for loading the dataset.
-    :return: Training and testing datasets, normalizers, and points.
-    :rtype: tuple
-        - u_train (torch.Tensor): Training simulations.
-        - p_train (torch.Tensor): Training parameters.
-    """
-    snapshots = load_dataset("SISSAmathLab/thermal-conduction", name="snapshots")["default"]
-    geom = load_dataset("SISSAmathLab/thermal-conduction", name="geometry")["default"]
-    points = torch.tensor(np.array(geom["points"]), dtype=torch.float32)
-
-    temperature = torch.tensor(np.array(snapshots["temperature"]), dtype=torch.float32)
-    params = torch.tensor(np.array(snapshots["params"]), dtype=torch.float32)
-
-    train_size = int(0.8 * len(temperature))
-    temp_train, temp_test = temperature[:train_size], temperature[train_size:]
-    params_train, params_test = params[:train_size], params[train_size:]
-
-    return temp_train, params_train, temp_test, params_test, points
-
-
-def load_trainer(trainer_args, solver):
-    """
-    Load and configure the Trainer for training.
-
-    :param dict trainer_args: Arguments for the Trainer.
-    :param ~pina.solver.solver_interface.SolverInterface solver: The solver
-        instance to be used by the Trainer.
-    :return: Configured Trainer instance.
-    :rtype: Trainer
-    """
-    patience = trainer_args.pop("patience", 100)
-    es = EarlyStopping(
-        monitor='val_loss',
-        patience=patience,
-        mode='min',
-        verbose=True
-    )
-
-    checkpoint = ModelCheckpoint(
-        monitor='val_loss',
-        mode='min',
-        save_top_k=1,
-        filename='best_model',
-        save_weights_only=True
-    )
-    logger = TensorBoardLogger(
-        save_dir=trainer_args.pop('log_dir', 'logs'),
-        name=trainer_args.pop('name'),
-        version=f"{trainer_args.pop('version'):03d}" if trainer_args.get('version', None) is not None else None
-    )
-    trainer_args['callbacks'] = [es, checkpoint]
-    trainer_args['solver'] = solver
-    trainer_args['logger'] = logger
-
-    trainer = Trainer(**trainer_args)
-    return trainer
-
-
-def load_optimizer(optim_args):
-    """
-    Load the optimizer class and instantiate it with the provided arguments.
-
-    :param dict optim_args: Arguments for the optimizer class.
-    :return: An instance of the TorchOptimizer class.
-    :rtype: TorchOptimizer
-    """
-    print("Loading optimizer with args:", optim_args)
-    optim_class = optim_args.pop("optimizer_class", "")
-    module_path, class_name = optim_class.rsplit(".", 1)
-    module = importlib.import_module(module_path)
-    cls = getattr(module, class_name)
-    return TorchOptimizer(
-        cls,
-        **optim_args
-    )
-
-
-def train(trainer):
-    """
-    Train the model using the provided Trainer instance.
-
-    :param Trainer trainer: The Trainer instance configured for training.
-    """
-    trainer.train()
-
-
-def save_model(solver, trainer, problem, model, int_net):
-    """
-    Save the trained model and its components to disk.
-
-    :param ~pina.solver.solver_interface.SolverInterface solver: The solver
-        instance containing the trained model.
-    :param Trainer trainer: The Trainer instance used for training.
-    :param ~pina.problem.zoo.SupervisedProblem problem: The problem instance
-        associated with the solver.
-    :param torch.nn.Module model: The trained model to be saved.
-    :param torch.nn.Module int_net: The interpolation network, if used.
-    """
-    model_path = trainer.logger.log_dir.replace("logs", "models")
-    os.makedirs(model_path, exist_ok=True)
-    if int_net is None:
-        solver = SupervisedSolver.load_from_checkpoint(
-            os.path.join(trainer.logger.log_dir, 'checkpoints',
-                         'best_model.ckpt'),
-            problem=problem,
-            model=model,
-            use_lt=False)
-        model = solver.model.cpu()
-        model.eval()
-        torch.save(model.state_dict(), os.path.join(model_path, 'model.pth'))
-        if hasattr(model, 'pod'):
-            torch.save(model.pod.basis, os.path.join(model_path,
-                                                     'pod_basis.pth')
-            )
-        return model
-    else:
-        solver = ReducedOrderModelSolver.load_from_checkpoint(
-            os.path.join(trainer.logger.log_dir, 'checkpoints',
-                         'best_model.ckpt'),
-            problem=problem,
-            interpolation_network=int_net,
-            reduction_network=model
-        )
-        int_net = solver.model["interpolation_network"].cpu()
-        torch.save(int_net.state_dict(), os.path.join(model_path,
-                                                      'interpolation_network.pth')
-        )
-        model = solver.model["reduction_network"].cpu()
-        torch.save(model.state_dict(), os.path.join(model_path,
-                                                    'reduction_network.pth')
-        )
-        if hasattr(model, 'pod'):
-            torch.save(model.pod.basis, os.path.join(model_path,
-                                                     'pod_basis.pth')
-            )
+from lightning import Trainer
+from ThermalSolver.module import GraphSolver
+from ThermalSolver.data_module import GraphDataModule
+from ThermalSolver.model.local_gno import GatingGNO
+from ThermalSolver.model.basic_gno import GNO


 def main():
-    seed_everything(1999, workers=True)
-    args = argparse()
-    config = load_config(args.config)
-    config_ = deepcopy(config)
-    model_args = config.get("model", {})
-    model = load_model(model_args)
-
-    if "interpolation" in config:
-        model_args = config["interpolation"]
-        int_net = load_model(model_args)
-    else:
-        int_net = None
-
-    temperature, params, _, _, points = load_data()
-    problem = SupervisedProblem(output_=temperature, input_=params)
-    optimizer = load_optimizer(config.get("optimizer", {}))
-    if int_net is None:
-        if hasattr(model, 'fit_pod'):
-            model.fit_pod(temperature)
-        solver = SupervisedSolver(
-            problem=problem,
-            model=model,
-            optimizer=optimizer,
-            use_lt=False
-        )
-    else:
-        if "pod" in config_["model"]["model_class"]:
-            model.fit(temperature)
-        solver = ReducedOrderModelSolver(
-            problem= problem,
-            reduction_network=model,
-            interpolation_network=int_net,
-            optimizer=optimizer,
-        )
-    trainer_args = config.get("trainer", {})
-    trainer = load_trainer(trainer_args, solver)
-    train(trainer)
-    model_test = save_model(solver, trainer, problem, model, int_net)
-    model_test.eval()
-    temp_pred = model_test(params).detach().numpy()
-    temp_true = temperature.detach().numpy()
-    error = compute_error(temp_true, temp_pred)
-    print(points.shape)
-    tria = Triangulation(points[0,:,0], points[0,:,1])
-
-    levels_main = torch.linspace(0, temp_true[0].max().item(), steps=100)
-    abs_error = np.abs(temp_true[0]- temp_pred[0])
-    levels_diff = torch.linspace(0, abs_error.max().item(), steps=100)
-    fig, axs = plt.subplots(1,3, figsize=(15,5))
-    im = axs[0].tricontourf(tria, temp_true[0], levels=levels_main)
-    axs[0].set_title('True Temperature')
-    fig.colorbar(im, ax=axs[0])
-    im = axs[1].tricontourf(tria, temp_pred[0], levels=levels_main)
-    axs[1].set_title('Predicted Temperature')
-    fig.colorbar(im, ax=axs[1])
-    im = axs[2].tricontourf(tria, np.abs(temp_true[0]- temp_pred[0]), levels=levels_diff)
-    axs[2].set_title('Absolute Error')
-    fig.colorbar(im, ax=axs[2])
-    plt.savefig('temperature_comparison.png')
-    print(f"L2 error on training set: {error:.6f}")
+    trainer = Trainer(max_epochs=100, accelerator="cuda", devices=1)
+    data_module = GraphDataModule(
+        hf_repo="SISSAmathLab/thermal-conduction",
+        split_name="easy",
+        train_size=0.8,
+        val_size=0.1,
+        test_size=0.1,
+        batch_size=8,
+    )
+    model = GatingGNO(x_ch_node=1, f_ch_node=1, hidden=16, layers=8, edge_ch=3, out_ch=1)
+    solver = GraphSolver(model)
+    trainer.fit(solver, datamodule=data_module)
+    print("Done!")

 if __name__ == "__main__":
     main()