implement first GNO
This commit is contained in:
@@ -64,7 +64,7 @@ class GraphDataModule(LightningDataModule):
|
|||||||
edge_attr = torch.cat(
|
edge_attr = torch.cat(
|
||||||
[edge_attr, torch.norm(edge_attr, dim=1).unsqueeze(-1)], dim=1
|
[edge_attr, torch.norm(edge_attr, dim=1).unsqueeze(-1)], dim=1
|
||||||
)
|
)
|
||||||
|
|
||||||
return Data(
|
return Data(
|
||||||
x=boundary_vales.unsqueeze(-1),
|
x=boundary_vales.unsqueeze(-1),
|
||||||
c=conductivity.unsqueeze(-1),
|
c=conductivity.unsqueeze(-1),
|
||||||
|
|||||||
25
ThermalSolver/model/basic_gno.py
Normal file
25
ThermalSolver/model/basic_gno.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
from pina.model import GraphNeuralOperator
|
||||||
|
import torch
|
||||||
|
from torch_geometric.data import Data
|
||||||
|
|
||||||
|
|
||||||
|
class GNO(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, x_ch_node, f_ch_node, hidden, layers, edge_ch=0, out_ch=1
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
lifting_operator = torch.nn.Linear(x_ch_node + f_ch_node, hidden)
|
||||||
|
self.gno = GraphNeuralOperator(
|
||||||
|
lifting_operator=lifting_operator,
|
||||||
|
projection_operator=torch.nn.Linear(hidden, out_ch),
|
||||||
|
edge_features=edge_ch,
|
||||||
|
n_layers=layers,
|
||||||
|
internal_n_layers=2,
|
||||||
|
shared_weights=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, c, edge_index, edge_attr):
|
||||||
|
x = torch.cat([x, c], dim=-1)
|
||||||
|
x = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
|
||||||
|
return self.gno(x)
|
||||||
@@ -8,19 +8,16 @@ class FiLM(nn.Module):
|
|||||||
def __init__(self, c_ch, h_ch):
|
def __init__(self, c_ch, h_ch):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.net = nn.Sequential(
|
self.net = nn.Sequential(
|
||||||
nn.Linear(c_ch, 2*h_ch),
|
nn.Linear(c_ch, 2 * h_ch), nn.SiLU(), nn.Linear(2 * h_ch, 2 * h_ch)
|
||||||
nn.SiLU(),
|
|
||||||
nn.Linear(2*h_ch, 2*h_ch)
|
|
||||||
)
|
)
|
||||||
# init to identity: gamma≈0 (so 1+gamma=1), beta=0
|
# init to identity: gamma≈0 (so 1+gamma=1), beta=0
|
||||||
nn.init.zeros_(self.net[-1].weight)
|
nn.init.zeros_(self.net[-1].weight)
|
||||||
nn.init.zeros_(self.net[-1].bias)
|
nn.init.zeros_(self.net[-1].bias)
|
||||||
self.norm = nn.LayerNorm(h_ch)
|
|
||||||
|
|
||||||
def forward(self, h, c):
|
def forward(self, h, c):
|
||||||
gb = self.net(c)
|
gb = self.net(c)
|
||||||
gamma, beta = gb.chunk(2, dim=-1)
|
gamma, beta = gb.chunk(2, dim=-1)
|
||||||
return (1 + gamma) * self.norm(h) + beta
|
return (1 + gamma) * h + beta
|
||||||
|
|
||||||
|
|
||||||
class ConditionalGNOBlock(MessagePassing):
|
class ConditionalGNOBlock(MessagePassing):
|
||||||
@@ -28,46 +25,35 @@ class ConditionalGNOBlock(MessagePassing):
|
|||||||
Message passing with FiLM applied to the MESSAGE m_ij,
|
Message passing with FiLM applied to the MESSAGE m_ij,
|
||||||
using edge context c_ij = (c_i + c_j)/2.
|
using edge context c_ij = (c_i + c_j)/2.
|
||||||
"""
|
"""
|
||||||
def __init__(self, hidden_ch, edge_ch=0, aggr="mean"):
|
|
||||||
|
def __init__(self, hidden_ch, edge_ch=0, aggr="add"):
|
||||||
super().__init__(aggr=aggr, node_dim=0)
|
super().__init__(aggr=aggr, node_dim=0)
|
||||||
self.pre_norm = nn.LayerNorm(hidden_ch)
|
|
||||||
|
|
||||||
# raw message builder
|
|
||||||
self.msg = nn.Sequential(
|
|
||||||
nn.Linear(2*hidden_ch + edge_ch, 2*hidden_ch),
|
|
||||||
nn.SiLU(),
|
|
||||||
nn.Linear(2*hidden_ch, hidden_ch)
|
|
||||||
)
|
|
||||||
|
|
||||||
# FiLM over the message (per-edge)
|
# FiLM over the message (per-edge)
|
||||||
self.film_msg = FiLM(c_ch=hidden_ch, h_ch=hidden_ch)
|
self.film_msg = FiLM(c_ch=hidden_ch, h_ch=hidden_ch)
|
||||||
|
self.edge_attr_net = nn.Sequential(
|
||||||
# node update with residual
|
nn.Linear(edge_ch, hidden_ch // 2),
|
||||||
self.update_mlp = nn.Sequential(
|
|
||||||
nn.Linear(2*hidden_ch, hidden_ch),
|
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
nn.Linear(hidden_ch, hidden_ch)
|
nn.Linear(hidden_ch // 2, hidden_ch),
|
||||||
|
)
|
||||||
|
self.x_net = nn.Sequential(
|
||||||
|
nn.Linear(hidden_ch, hidden_ch * 2),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(hidden_ch * 2, hidden_ch),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, c, edge_index, edge_attr=None):
|
def forward(self, x, c, edge_index, edge_attr=None):
|
||||||
# pre-norm helps stability
|
return self.propagate(edge_index, x=x, c=c, edge_attr=edge_attr)
|
||||||
x_in = x
|
|
||||||
x = self.pre_norm(x)
|
|
||||||
m = self.propagate(edge_index, x=x, c=c, edge_attr=edge_attr)
|
|
||||||
out = self.update_mlp(torch.cat([x_in, m], dim=-1))
|
|
||||||
return x_in + out # residual
|
|
||||||
|
|
||||||
def message(self, x_i, x_j, c_i, c_j, edge_attr):
|
def update(self, aggr_out, x):
|
||||||
|
return self.x_net(x) + aggr_out
|
||||||
|
|
||||||
|
def message(self, x_j, c_i, c_j, edge_attr):
|
||||||
|
# c_ij = (c_i + c_j)/2
|
||||||
|
c_ij = 0.5 * (c_i + c_j)
|
||||||
|
m = self.film_msg(x_j, c_ij)
|
||||||
if edge_attr is not None:
|
if edge_attr is not None:
|
||||||
m_in = torch.cat([x_i, x_j, edge_attr], dim=-1)
|
a_ij = self.edge_attr_net(edge_attr)
|
||||||
else:
|
m = m * a_ij
|
||||||
m_in = torch.cat([x_i, x_j], dim=-1)
|
|
||||||
|
|
||||||
m_raw = self.msg(m_in)
|
|
||||||
|
|
||||||
# edge conditioning: simple mean
|
|
||||||
c_ctx = 0.5 * (c_i + c_j)
|
|
||||||
m = self.film_msg(m_raw, c_ctx)
|
|
||||||
return m
|
return m
|
||||||
|
|
||||||
|
|
||||||
@@ -79,7 +65,10 @@ class GatingGNO(nn.Module):
|
|||||||
Out:
|
Out:
|
||||||
y : [N, out_ch]
|
y : [N, out_ch]
|
||||||
"""
|
"""
|
||||||
def __init__(self, x_ch_node, f_ch_node, hidden, layers, edge_ch=0, out_ch=1):
|
|
||||||
|
def __init__(
|
||||||
|
self, x_ch_node, f_ch_node, hidden, layers, edge_ch=0, out_ch=1
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.encoder_x = nn.Sequential(
|
self.encoder_x = nn.Sequential(
|
||||||
nn.Linear(x_ch_node, hidden // 2),
|
nn.Linear(x_ch_node, hidden // 2),
|
||||||
@@ -92,12 +81,15 @@ class GatingGNO(nn.Module):
|
|||||||
nn.Linear(hidden // 2, hidden),
|
nn.Linear(hidden // 2, hidden),
|
||||||
)
|
)
|
||||||
self.blocks = nn.ModuleList(
|
self.blocks = nn.ModuleList(
|
||||||
[ConditionalGNOBlock(hidden_ch=hidden, edge_ch=edge_ch) for _ in range(layers)]
|
[
|
||||||
|
ConditionalGNOBlock(hidden_ch=hidden, edge_ch=edge_ch)
|
||||||
|
for _ in range(layers)
|
||||||
|
]
|
||||||
)
|
)
|
||||||
self.dec = nn.Sequential(
|
self.dec = nn.Sequential(
|
||||||
nn.LayerNorm(hidden),
|
nn.Linear(hidden, hidden // 2),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
nn.Linear(hidden, out_ch)
|
nn.Linear(hidden // 2, out_ch),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, c, edge_index, edge_attr=None):
|
def forward(self, x, c, edge_index, edge_attr=None):
|
||||||
@@ -105,4 +97,4 @@ class GatingGNO(nn.Module):
|
|||||||
c = self.encoder_c(c) # [N,H]
|
c = self.encoder_c(c) # [N,H]
|
||||||
for blk in self.blocks:
|
for blk in self.blocks:
|
||||||
x = blk(x, c, edge_index, edge_attr=edge_attr)
|
x = blk(x, c, edge_index, edge_attr=edge_attr)
|
||||||
return self.dec(x)
|
return self.dec(x)
|
||||||
|
|||||||
@@ -4,7 +4,12 @@ from torch_geometric.data import Batch
|
|||||||
|
|
||||||
|
|
||||||
class GraphSolver(LightningModule):
|
class GraphSolver(LightningModule):
|
||||||
def __init__(self, model: torch.nn.Module, loss: torch.nn.Module = None, unrolling_steps: int = 10):
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: torch.nn.Module,
|
||||||
|
loss: torch.nn.Module = None,
|
||||||
|
unrolling_steps: int = 10,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model = model
|
self.model = model
|
||||||
self.loss = loss if loss is not None else torch.nn.MSELoss()
|
self.loss = loss if loss is not None else torch.nn.MSELoss()
|
||||||
@@ -18,7 +23,7 @@ class GraphSolver(LightningModule):
|
|||||||
edge_attr: torch.Tensor,
|
edge_attr: torch.Tensor,
|
||||||
):
|
):
|
||||||
return self.model(x, c, edge_index, edge_attr)
|
return self.model(x, c, edge_index, edge_attr)
|
||||||
|
|
||||||
def _compute_loss_train(self, x, x_prev, y):
|
def _compute_loss_train(self, x, x_prev, y):
|
||||||
return self.loss(x, y) + self.loss(x, x_prev)
|
return self.loss(x, y) + self.loss(x, x_prev)
|
||||||
|
|
||||||
@@ -27,7 +32,7 @@ class GraphSolver(LightningModule):
|
|||||||
|
|
||||||
def _preprocess_batch(self, batch: Batch):
|
def _preprocess_batch(self, batch: Batch):
|
||||||
return batch.x, batch.y, batch.c, batch.edge_index, batch.edge_attr
|
return batch.x, batch.y, batch.c, batch.edge_index, batch.edge_attr
|
||||||
|
|
||||||
def _log_loss(self, loss, batch, stage: str):
|
def _log_loss(self, loss, batch, stage: str):
|
||||||
self.log(
|
self.log(
|
||||||
f"{stage}_loss",
|
f"{stage}_loss",
|
||||||
@@ -41,13 +46,14 @@ class GraphSolver(LightningModule):
|
|||||||
|
|
||||||
def training_step(self, batch: Batch, _):
|
def training_step(self, batch: Batch, _):
|
||||||
x, y, c, edge_index, edge_attr = self._preprocess_batch(batch)
|
x, y, c, edge_index, edge_attr = self._preprocess_batch(batch)
|
||||||
|
loss = 0.0
|
||||||
for _ in range(self.unrolling_steps):
|
for _ in range(self.unrolling_steps):
|
||||||
x_prev = x.detach()
|
x_prev = x.detach()
|
||||||
x = self(x_prev, c, edge_index=edge_index, edge_attr=edge_attr)
|
x = self(x_prev, c, edge_index=edge_index, edge_attr=edge_attr)
|
||||||
loss = self.loss(x, y)
|
loss += self.loss(x, y)
|
||||||
self._log_loss(loss, batch, "train")
|
self._log_loss(loss, batch, "train")
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def validation_step(self, batch: Batch, _):
|
def validation_step(self, batch: Batch, _):
|
||||||
x, y, c, edge_index, edge_attr = self._preprocess_batch(batch)
|
x, y, c, edge_index, edge_attr = self._preprocess_batch(batch)
|
||||||
for _ in range(self.unrolling_steps):
|
for _ in range(self.unrolling_steps):
|
||||||
@@ -70,5 +76,5 @@ class GraphSolver(LightningModule):
|
|||||||
return loss
|
return loss
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
optimizer = torch.optim.Adam(self.parameters(), lr=5e-3)
|
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|||||||
276
run.py
276
run.py
@@ -1,264 +1,24 @@
|
|||||||
import os
|
from lightning import Trainer
|
||||||
import yaml
|
from ThermalSolver.module import GraphSolver
|
||||||
import importlib
|
from ThermalSolver.data_module import GraphDataModule
|
||||||
from pina import Trainer
|
from ThermalSolver.model.local_gno import GatingGNO
|
||||||
import torch
|
from ThermalSolver.model.basic_gno import GNO
|
||||||
import numpy as np
|
|
||||||
from pina.problem.zoo import SupervisedProblem
|
|
||||||
from pina.solver import ReducedOrderModelSolver
|
|
||||||
from pina.solver import SupervisedSolver
|
|
||||||
from pina.optim import TorchOptimizer
|
|
||||||
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
|
|
||||||
from lightning.pytorch import seed_everything
|
|
||||||
from lightning.pytorch.loggers import TensorBoardLogger
|
|
||||||
from copy import deepcopy
|
|
||||||
from datasets import load_dataset
|
|
||||||
from torch.utils.data import random_split
|
|
||||||
from matplotlib import pyplot as plt
|
|
||||||
from matplotlib.tri import Triangulation
|
|
||||||
|
|
||||||
|
|
||||||
def compute_error(u_true, u_pred):
|
|
||||||
"""Compute the L2 error between true and predicted solutions."""
|
|
||||||
return np.linalg.norm(u_true - u_pred) / np.linalg.norm(u_true)
|
|
||||||
|
|
||||||
def argparse():
|
|
||||||
"""
|
|
||||||
Parse command line arguments for training configuration.
|
|
||||||
"""
|
|
||||||
import argparse
|
|
||||||
parser = argparse.ArgumentParser(description="Train a model with specified "
|
|
||||||
"parameters.")
|
|
||||||
parser.add_argument('--config', type=str, required=True,
|
|
||||||
help='Path to the configuration YAML file.')
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
def load_config(config_file):
|
|
||||||
"""
|
|
||||||
Configure the training parameters.
|
|
||||||
|
|
||||||
:param str config_file: Path to the configuration file.
|
|
||||||
:return: Configuration dictionary.
|
|
||||||
:rtype: dict
|
|
||||||
"""
|
|
||||||
with open(config_file, "r") as f:
|
|
||||||
config = yaml.safe_load(f)
|
|
||||||
return config
|
|
||||||
|
|
||||||
def load_model(model_args):
|
|
||||||
"""
|
|
||||||
Load the model class and instantiate it with the provided arguments.
|
|
||||||
|
|
||||||
:param dict model_args: Arguments for the model class.
|
|
||||||
:return: An instance of the model class.
|
|
||||||
:rtype: torch.nn.Module
|
|
||||||
"""
|
|
||||||
model_class = model_args.pop("model_class", "")
|
|
||||||
module_path, class_name = model_class.rsplit(".", 1)
|
|
||||||
module = importlib.import_module(module_path)
|
|
||||||
cls = getattr(module, class_name)
|
|
||||||
model = cls(**model_args)
|
|
||||||
return model
|
|
||||||
|
|
||||||
def load_data():
|
|
||||||
"""
|
|
||||||
Load the dataset from the specified path and preprocess it.
|
|
||||||
|
|
||||||
:param dict data_args: Arguments for loading the dataset.
|
|
||||||
:return: Training and testing datasets, normalizers, and points.
|
|
||||||
:rtype: tuple
|
|
||||||
- u_train (torch.Tensor): Training simulations.
|
|
||||||
- p_train (torch.Tensor): Training parameters.
|
|
||||||
"""
|
|
||||||
snapshots = load_dataset("SISSAmathLab/thermal-conduction", name="snapshots")["default"]
|
|
||||||
geom = load_dataset("SISSAmathLab/thermal-conduction", name="geometry")["default"]
|
|
||||||
points = torch.tensor(np.array(geom["points"]), dtype=torch.float32)
|
|
||||||
|
|
||||||
temperature = torch.tensor(np.array(snapshots["temperature"]), dtype=torch.float32)
|
|
||||||
params = torch.tensor(np.array(snapshots["params"]), dtype=torch.float32)
|
|
||||||
|
|
||||||
train_size = int(0.8 * len(temperature))
|
|
||||||
temp_train, temp_test = temperature[:train_size], temperature[train_size:]
|
|
||||||
params_train, params_test = params[:train_size], params[train_size:]
|
|
||||||
|
|
||||||
return temp_train, params_train, temp_test, params_test, points
|
|
||||||
|
|
||||||
|
|
||||||
def load_trainer(trainer_args, solver):
|
|
||||||
"""
|
|
||||||
Load and configure the Trainer for training.
|
|
||||||
|
|
||||||
:param dict trainer_args: Arguments for the Trainer.
|
|
||||||
:param ~pina.solver.solver_interface.SolverInterface solver: The solver
|
|
||||||
instance to be used by the Trainer.
|
|
||||||
:return: Configured Trainer instance.
|
|
||||||
:rtype: Trainer
|
|
||||||
"""
|
|
||||||
patience = trainer_args.pop("patience", 100)
|
|
||||||
es = EarlyStopping(
|
|
||||||
monitor='val_loss',
|
|
||||||
patience=patience,
|
|
||||||
mode='min',
|
|
||||||
verbose=True
|
|
||||||
)
|
|
||||||
|
|
||||||
checkpoint = ModelCheckpoint(
|
|
||||||
monitor='val_loss',
|
|
||||||
mode='min',
|
|
||||||
save_top_k=1,
|
|
||||||
filename='best_model',
|
|
||||||
save_weights_only=True
|
|
||||||
)
|
|
||||||
logger = TensorBoardLogger(
|
|
||||||
save_dir=trainer_args.pop('log_dir', 'logs'),
|
|
||||||
name=trainer_args.pop('name'),
|
|
||||||
version=f"{trainer_args.pop('version'):03d}" if trainer_args.get('version', None) is not None else None
|
|
||||||
)
|
|
||||||
trainer_args['callbacks'] = [es, checkpoint]
|
|
||||||
trainer_args['solver'] = solver
|
|
||||||
trainer_args['logger'] = logger
|
|
||||||
|
|
||||||
trainer = Trainer(**trainer_args)
|
|
||||||
return trainer
|
|
||||||
|
|
||||||
def load_optimizer(optim_args):
|
|
||||||
"""
|
|
||||||
Load the optimizer class and instantiate it with the provided arguments.
|
|
||||||
|
|
||||||
:param dict optim_args: Arguments for the optimizer class.
|
|
||||||
:return: An instance of the TorchOptimizer class.
|
|
||||||
:rtype: TorchOptimizer
|
|
||||||
"""
|
|
||||||
print("Loading optimizer with args:", optim_args)
|
|
||||||
optim_class = optim_args.pop("optimizer_class", "")
|
|
||||||
module_path, class_name = optim_class.rsplit(".", 1)
|
|
||||||
module = importlib.import_module(module_path)
|
|
||||||
cls = getattr(module, class_name)
|
|
||||||
return TorchOptimizer(
|
|
||||||
cls,
|
|
||||||
**optim_args
|
|
||||||
)
|
|
||||||
|
|
||||||
def train(trainer):
|
|
||||||
"""
|
|
||||||
Train the model using the provided Trainer instance.
|
|
||||||
|
|
||||||
:param Trainer trainer: The Trainer instance configured for training.
|
|
||||||
"""
|
|
||||||
trainer.train()
|
|
||||||
|
|
||||||
def save_model(solver, trainer, problem, model, int_net):
|
|
||||||
"""
|
|
||||||
Save the trained model and its components to disk.
|
|
||||||
|
|
||||||
:param ~pina.solver.solver_interface.SolverInterface solver: The solver
|
|
||||||
instance containing the trained model.
|
|
||||||
:param Trainer trainer: The Trainer instance used for training.
|
|
||||||
:param ~pina.problem.zoo.SupervisedProblem problem: The problem instance
|
|
||||||
associated with the solver.
|
|
||||||
:param torch.nn.Module model: The trained model to be saved.
|
|
||||||
:param torch.nn.Module int_net: The interpolation network, if used.
|
|
||||||
"""
|
|
||||||
model_path = trainer.logger.log_dir.replace("logs", "models")
|
|
||||||
os.makedirs(model_path, exist_ok=True)
|
|
||||||
if int_net is None:
|
|
||||||
solver = SupervisedSolver.load_from_checkpoint(
|
|
||||||
os.path.join(trainer.logger.log_dir, 'checkpoints',
|
|
||||||
'best_model.ckpt'),
|
|
||||||
problem=problem,
|
|
||||||
model=model,
|
|
||||||
use_lt=False)
|
|
||||||
model = solver.model.cpu()
|
|
||||||
model.eval()
|
|
||||||
torch.save(model.state_dict(), os.path.join(model_path, 'model.pth'))
|
|
||||||
if hasattr(model, 'pod'):
|
|
||||||
torch.save(model.pod.basis, os.path.join(model_path,
|
|
||||||
'pod_basis.pth')
|
|
||||||
)
|
|
||||||
return model
|
|
||||||
else:
|
|
||||||
solver = ReducedOrderModelSolver.load_from_checkpoint(
|
|
||||||
os.path.join(trainer.logger.log_dir, 'checkpoints',
|
|
||||||
'best_model.ckpt'),
|
|
||||||
problem=problem,
|
|
||||||
interpolation_network=int_net,
|
|
||||||
reduction_network=model
|
|
||||||
)
|
|
||||||
int_net = solver.model["interpolation_network"].cpu()
|
|
||||||
torch.save(int_net.state_dict(), os.path.join(model_path,
|
|
||||||
'interpolation_network.pth')
|
|
||||||
)
|
|
||||||
model = solver.model["reduction_network"].cpu()
|
|
||||||
torch.save(model.state_dict(), os.path.join(model_path,
|
|
||||||
'reduction_network.pth')
|
|
||||||
)
|
|
||||||
if hasattr(model, 'pod'):
|
|
||||||
torch.save(model.pod.basis, os.path.join(model_path,
|
|
||||||
'pod_basis.pth')
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
seed_everything(1999, workers=True)
|
trainer = Trainer(max_epochs=100, accelerator="cuda", devices=1)
|
||||||
args = argparse()
|
data_module = GraphDataModule(
|
||||||
config = load_config(args.config)
|
hf_repo="SISSAmathLab/thermal-conduction",
|
||||||
config_ = deepcopy(config)
|
split_name="easy",
|
||||||
model_args = config.get("model", {})
|
train_size=0.8,
|
||||||
model = load_model(model_args)
|
val_size=0.1,
|
||||||
|
test_size=0.1,
|
||||||
if "interpolation" in config:
|
batch_size=8,
|
||||||
model_args = config["interpolation"]
|
)
|
||||||
int_net = load_model(model_args)
|
model = GatingGNO(x_ch_node=1, f_ch_node=1, hidden=16, layers=8, edge_ch=3, out_ch=1)
|
||||||
else:
|
solver = GraphSolver(model)
|
||||||
int_net = None
|
trainer.fit(solver, datamodule=data_module)
|
||||||
|
print("Done!")
|
||||||
temperature, params, _, _, points = load_data()
|
|
||||||
problem = SupervisedProblem(output_=temperature, input_=params)
|
|
||||||
optimizer = load_optimizer(config.get("optimizer", {}))
|
|
||||||
if int_net is None:
|
|
||||||
if hasattr(model, 'fit_pod'):
|
|
||||||
model.fit_pod(temperature)
|
|
||||||
solver = SupervisedSolver(
|
|
||||||
problem=problem,
|
|
||||||
model=model,
|
|
||||||
optimizer=optimizer,
|
|
||||||
use_lt=False
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if "pod" in config_["model"]["model_class"]:
|
|
||||||
model.fit(temperature)
|
|
||||||
solver = ReducedOrderModelSolver(
|
|
||||||
problem= problem,
|
|
||||||
reduction_network=model,
|
|
||||||
interpolation_network=int_net,
|
|
||||||
optimizer=optimizer,
|
|
||||||
)
|
|
||||||
trainer_args = config.get("trainer", {})
|
|
||||||
trainer = load_trainer(trainer_args, solver)
|
|
||||||
train(trainer)
|
|
||||||
model_test = save_model(solver, trainer, problem, model, int_net)
|
|
||||||
model_test.eval()
|
|
||||||
temp_pred = model_test(params).detach().numpy()
|
|
||||||
temp_true = temperature.detach().numpy()
|
|
||||||
error = compute_error(temp_true, temp_pred)
|
|
||||||
print(points.shape)
|
|
||||||
tria = Triangulation(points[0,:,0], points[0,:,1])
|
|
||||||
|
|
||||||
levels_main = torch.linspace(0, temp_true[0].max().item(), steps=100)
|
|
||||||
abs_error = np.abs(temp_true[0]- temp_pred[0])
|
|
||||||
levels_diff = torch.linspace(0, abs_error.max().item(), steps=100)
|
|
||||||
fig, axs = plt.subplots(1,3, figsize=(15,5))
|
|
||||||
im = axs[0].tricontourf(tria, temp_true[0], levels=levels_main)
|
|
||||||
axs[0].set_title('True Temperature')
|
|
||||||
fig.colorbar(im, ax=axs[0])
|
|
||||||
im = axs[1].tricontourf(tria, temp_pred[0], levels=levels_main)
|
|
||||||
axs[1].set_title('Predicted Temperature')
|
|
||||||
fig.colorbar(im, ax=axs[1])
|
|
||||||
im = axs[2].tricontourf(tria, np.abs(temp_true[0]- temp_pred[0]), levels=levels_diff)
|
|
||||||
axs[2].set_title('Absolute Error')
|
|
||||||
fig.colorbar(im, ax=axs[2])
|
|
||||||
plt.savefig('temperature_comparison.png')
|
|
||||||
print(f"L2 error on training set: {error:.6f}")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
Reference in New Issue
Block a user