export tutorials changed in 9c60f61 (#643)

Co-authored-by: dario-coscia <dario-coscia@users.noreply.github.com>
This commit is contained in:
github-actions[bot]
2025-09-15 19:35:24 +02:00
committed by GitHub
parent 9c60f616b7
commit ef75f13bcb
5 changed files with 12073 additions and 125 deletions

View File

@@ -40,7 +40,10 @@
"import torch\n",
"from torch import nn\n",
"from torch_geometric.nn import GMMConv\n",
"from torch_geometric.data import Data, Batch # alternatively, from pina.graph import Graph, LabelBatch\n",
"from torch_geometric.data import (\n",
" Data,\n",
" Batch,\n",
") # alternatively, from pina.graph import Graph, LabelBatch\n",
"from torch_geometric.utils import to_dense_batch\n",
"\n",
"import matplotlib.pyplot as plt\n",
@@ -105,17 +108,17 @@
"# u, params -> solution field, parameters\n",
"\n",
"data = torch.load(\"holed_poisson.pt\")\n",
"x = data['x']\n",
"y = data['y']\n",
"edge_index = data['edge_index']\n",
"u = data['u']\n",
"triang = data['triang']\n",
"params = data['mu']\n",
"x = data[\"x\"]\n",
"y = data[\"y\"]\n",
"edge_index = data[\"edge_index\"]\n",
"u = data[\"u\"]\n",
"triang = data[\"triang\"]\n",
"params = data[\"mu\"]\n",
"\n",
"# simple plot\n",
"plt.figure(figsize=(4, 4))\n",
"plt.tricontourf(x[:, 10], y[:, 10], triang, u[:, 10], 100, cmap='jet')\n",
"plt.scatter(params[10, 0], params[10, 1], c='r', marker=\"x\", s=100)\n",
"plt.tricontourf(x[:, 10], y[:, 10], triang, u[:, 10], 100, cmap=\"jet\")\n",
"plt.scatter(params[10, 0], params[10, 1], c=\"r\", marker=\"x\", s=100)\n",
"plt.tight_layout()\n",
"plt.show()"
]
@@ -267,7 +270,7 @@
" # edge attributes and weights\n",
" ei, ej = pos[edge_index[0]], pos[edge_index[1]] # [num_edges, 2]\n",
" edge_attr = torch.abs(ej - ei) # relative offsets\n",
" edge_weight = edge_attr.norm(p=2, dim=1, keepdim=True) # Euclidean distance\n",
" edge_weight = edge_attr.norm(p=2, dim=1, keepdim=True) # Euclidean distance\n",
" # node features (solution values)\n",
" node_features = u[:, g].unsqueeze(-1) # [num_nodes, 1]\n",
" # build PyG graph\n",
@@ -327,7 +330,11 @@
" hidden_channels=[1, 1], bottleneck=8, input_size=1352, ffn=200, act=nn.ELU\n",
")\n",
"interpolation_network = FeedForward(\n",
" input_dimensions=2, output_dimensions=8, n_layers=2, inner_size=200, func=nn.Tanh\n",
" input_dimensions=2,\n",
" output_dimensions=8,\n",
" n_layers=2,\n",
" inner_size=200,\n",
" func=nn.Tanh,\n",
")"
]
},
@@ -361,6 +368,7 @@
" output, target, reduction=self.reduction\n",
" )\n",
"\n",
"\n",
"# Define the solver\n",
"solver = ReducedOrderModelSolver(\n",
" problem=problem,\n",
@@ -393,7 +401,7 @@
" max_epochs=300,\n",
" train_size=0.3,\n",
" val_size=0.7,\n",
" test_size=0.,\n",
" test_size=0.0,\n",
" shuffle=True,\n",
")\n",
"trainer.train()"
@@ -481,10 +489,10 @@
" vmin=vmin,\n",
" vmax=vmax,\n",
")\n",
"plt.title('GCA-ROM')\n",
"plt.title(\"GCA-ROM\")\n",
"plt.colorbar()\n",
"plt.subplot(1, 3, 2)\n",
"plt.title('True')\n",
"plt.title(\"True\")\n",
"plt.tricontourf(\n",
" x[:, idx_to_plot],\n",
" y[:, idx_to_plot],\n",
@@ -497,8 +505,15 @@
")\n",
"plt.colorbar()\n",
"plt.subplot(1, 3, 3)\n",
"plt.title('Square Error')\n",
"plt.tricontourf(x[:, idx_to_plot], y[:, idx_to_plot], triang, (u-out).pow(2)[:, idx_to_plot], 100, cmap='jet')\n",
"plt.title(\"Square Error\")\n",
"plt.tricontourf(\n",
" x[:, idx_to_plot],\n",
" y[:, idx_to_plot],\n",
" triang,\n",
" (u - out).pow(2)[:, idx_to_plot],\n",
" 100,\n",
" cmap=\"jet\",\n",
")\n",
"plt.colorbar()\n",
"plt.ticklabel_format()\n",
"plt.show()"