Co-authored-by: dario-coscia <dario-coscia@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
9c60f616b7
commit
ef75f13bcb
47
tutorials/tutorial22/tutorial.ipynb
vendored
47
tutorials/tutorial22/tutorial.ipynb
vendored
@@ -40,7 +40,10 @@
|
||||
"import torch\n",
|
||||
"from torch import nn\n",
|
||||
"from torch_geometric.nn import GMMConv\n",
|
||||
"from torch_geometric.data import Data, Batch # alternatively, from pina.graph import Graph, LabelBatch\n",
|
||||
"from torch_geometric.data import (\n",
|
||||
" Data,\n",
|
||||
" Batch,\n",
|
||||
") # alternatively, from pina.graph import Graph, LabelBatch\n",
|
||||
"from torch_geometric.utils import to_dense_batch\n",
|
||||
"\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
@@ -105,17 +108,17 @@
|
||||
"# u, params -> solution field, parameters\n",
|
||||
"\n",
|
||||
"data = torch.load(\"holed_poisson.pt\")\n",
|
||||
"x = data['x']\n",
|
||||
"y = data['y']\n",
|
||||
"edge_index = data['edge_index']\n",
|
||||
"u = data['u']\n",
|
||||
"triang = data['triang']\n",
|
||||
"params = data['mu']\n",
|
||||
"x = data[\"x\"]\n",
|
||||
"y = data[\"y\"]\n",
|
||||
"edge_index = data[\"edge_index\"]\n",
|
||||
"u = data[\"u\"]\n",
|
||||
"triang = data[\"triang\"]\n",
|
||||
"params = data[\"mu\"]\n",
|
||||
"\n",
|
||||
"# simple plot\n",
|
||||
"plt.figure(figsize=(4, 4))\n",
|
||||
"plt.tricontourf(x[:, 10], y[:, 10], triang, u[:, 10], 100, cmap='jet')\n",
|
||||
"plt.scatter(params[10, 0], params[10, 1], c='r', marker=\"x\", s=100)\n",
|
||||
"plt.tricontourf(x[:, 10], y[:, 10], triang, u[:, 10], 100, cmap=\"jet\")\n",
|
||||
"plt.scatter(params[10, 0], params[10, 1], c=\"r\", marker=\"x\", s=100)\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()"
|
||||
]
|
||||
@@ -267,7 +270,7 @@
|
||||
" # edge attributes and weights\n",
|
||||
" ei, ej = pos[edge_index[0]], pos[edge_index[1]] # [num_edges, 2]\n",
|
||||
" edge_attr = torch.abs(ej - ei) # relative offsets\n",
|
||||
" edge_weight = edge_attr.norm(p=2, dim=1, keepdim=True) # Euclidean distance\n",
|
||||
" edge_weight = edge_attr.norm(p=2, dim=1, keepdim=True) # Euclidean distance\n",
|
||||
" # node features (solution values)\n",
|
||||
" node_features = u[:, g].unsqueeze(-1) # [num_nodes, 1]\n",
|
||||
" # build PyG graph\n",
|
||||
@@ -327,7 +330,11 @@
|
||||
" hidden_channels=[1, 1], bottleneck=8, input_size=1352, ffn=200, act=nn.ELU\n",
|
||||
")\n",
|
||||
"interpolation_network = FeedForward(\n",
|
||||
" input_dimensions=2, output_dimensions=8, n_layers=2, inner_size=200, func=nn.Tanh\n",
|
||||
" input_dimensions=2,\n",
|
||||
" output_dimensions=8,\n",
|
||||
" n_layers=2,\n",
|
||||
" inner_size=200,\n",
|
||||
" func=nn.Tanh,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@@ -361,6 +368,7 @@
|
||||
" output, target, reduction=self.reduction\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Define the solver\n",
|
||||
"solver = ReducedOrderModelSolver(\n",
|
||||
" problem=problem,\n",
|
||||
@@ -393,7 +401,7 @@
|
||||
" max_epochs=300,\n",
|
||||
" train_size=0.3,\n",
|
||||
" val_size=0.7,\n",
|
||||
" test_size=0.,\n",
|
||||
" test_size=0.0,\n",
|
||||
" shuffle=True,\n",
|
||||
")\n",
|
||||
"trainer.train()"
|
||||
@@ -481,10 +489,10 @@
|
||||
" vmin=vmin,\n",
|
||||
" vmax=vmax,\n",
|
||||
")\n",
|
||||
"plt.title('GCA-ROM')\n",
|
||||
"plt.title(\"GCA-ROM\")\n",
|
||||
"plt.colorbar()\n",
|
||||
"plt.subplot(1, 3, 2)\n",
|
||||
"plt.title('True')\n",
|
||||
"plt.title(\"True\")\n",
|
||||
"plt.tricontourf(\n",
|
||||
" x[:, idx_to_plot],\n",
|
||||
" y[:, idx_to_plot],\n",
|
||||
@@ -497,8 +505,15 @@
|
||||
")\n",
|
||||
"plt.colorbar()\n",
|
||||
"plt.subplot(1, 3, 3)\n",
|
||||
"plt.title('Square Error')\n",
|
||||
"plt.tricontourf(x[:, idx_to_plot], y[:, idx_to_plot], triang, (u-out).pow(2)[:, idx_to_plot], 100, cmap='jet')\n",
|
||||
"plt.title(\"Square Error\")\n",
|
||||
"plt.tricontourf(\n",
|
||||
" x[:, idx_to_plot],\n",
|
||||
" y[:, idx_to_plot],\n",
|
||||
" triang,\n",
|
||||
" (u - out).pow(2)[:, idx_to_plot],\n",
|
||||
" 100,\n",
|
||||
" cmap=\"jet\",\n",
|
||||
")\n",
|
||||
"plt.colorbar()\n",
|
||||
"plt.ticklabel_format()\n",
|
||||
"plt.show()"
|
||||
|
||||
Reference in New Issue
Block a user