PINA/tutorials/tutorial5/tutorial.ipynb
Dario Coscia 29b14ee9b6 Update Tutorials (#544)
* update tutorials
* tutorial guidelines
* doc
2025-04-23 18:53:30 +02:00


{
"cells": [
{
"cell_type": "markdown",
"id": "e80567a6",
"metadata": {},
"source": [
"# Tutorial: Modeling 2D Darcy Flow with the Fourier Neural Operator\n",
"\n",
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mathLab/PINA/blob/master/tutorials/tutorial5/tutorial.ipynb)\n",
"\n",
"In this tutorial, we are going to solve the **Darcy flow problem** in two dimensions, as presented in the paper [*Fourier Neural Operator for Parametric Partial Differential Equations*](https://openreview.net/pdf?id=c8P9NQVtmnO).\n",
"\n",
"We begin by importing the necessary modules for the tutorial:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5f2744dc",
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-19T13:35:28.837348Z",
"start_time": "2024-09-19T13:35:27.611334Z"
}
},
"outputs": [],
"source": [
"## routine needed to run the notebook on Google Colab\n",
"try:\n",
" import google.colab\n",
"\n",
" IN_COLAB = True\n",
"except ImportError:\n",
" IN_COLAB = False\n",
"if IN_COLAB:\n",
" !pip install \"pina-mathlab[tutorial]\"\n",
" !pip install scipy\n",
" # get the data\n",
" !wget https://github.com/mathLab/PINA/raw/refs/heads/master/tutorials/tutorial5/Data_Darcy.mat\n",
"\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"import warnings\n",
"\n",
"from scipy import io\n",
"from pina.model import FNO, FeedForward\n",
"from pina import Trainer\n",
"from pina.solver import SupervisedSolver\n",
"from pina.problem.zoo import SupervisedProblem\n",
"\n",
"warnings.filterwarnings(\"ignore\")"
]
},
{
"cell_type": "markdown",
"id": "4cf5b181",
"metadata": {},
"source": [
"## Data Generation\n",
"\n",
"We will focus on solving a specific PDE: the **Darcy Flow** equation. This is a second-order elliptic PDE given by:\n",
"\n",
"$$\n",
"-\\nabla\\cdot(k(x, y)\\nabla u(x, y)) = f(x, y), \\quad (x, y) \\in D.\n",
"$$\n",
"\n",
"Here, $u$ represents the flow pressure, $k$ is the permeability field, and $f$ is the forcing function. The Darcy flow equation can be used to model various systems, including flow through porous media, elasticity in materials, and heat conduction.\n",
"\n",
"In this tutorial, the domain $D$ is defined as a 2D unit square with Dirichlet boundary conditions. The dataset used is taken from the authors' original implementation in the referenced paper."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2ffb8a4c",
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-19T13:35:28.989631Z",
"start_time": "2024-09-19T13:35:28.952744Z"
}
},
"outputs": [],
"source": [
"# load the dataset from the local .mat file\n",
"data = io.loadmat(\"Data_Darcy.mat\")\n",
"\n",
"# extract the data (only 100 samples are used for training)\n",
"k_train = torch.tensor(data[\"k_train\"], dtype=torch.float)\n",
"u_train = torch.tensor(data[\"u_train\"], dtype=torch.float)\n",
"k_test = torch.tensor(data[\"k_test\"], dtype=torch.float)\n",
"u_test = torch.tensor(data[\"u_test\"], dtype=torch.float)\n",
"x = torch.tensor(data[\"x\"], dtype=torch.float)[0]\n",
"y = torch.tensor(data[\"y\"], dtype=torch.float)[0]"
]
},
{
"cell_type": "markdown",
"id": "9a9defd4",
"metadata": {},
"source": [
"Before diving into modeling, it's helpful to visualize some examples from the dataset. This will give us a better understanding of the input (permeability field) and the corresponding output (pressure field) that our model will learn to predict."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "c8501b6f",
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-19T13:35:29.108381Z",
"start_time": "2024-09-19T13:35:29.031076Z"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAEjCAYAAAARyVqhAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8ekN5oAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA2+UlEQVR4nO3dC3xU5Zk/8GcumUmAJIAEkigXQQFFLoqFBqFIoSDrolBLkbUlUKW7Luzqhw/Wxg83LzVV1kstLFi3iK4VkFZhu7psEQXKAiograyWJRRIIgkkgdwvczv/z/PynzETMpn3wZzJmcnv+/kcwsy8c/LO5Tx5znvOeV6bYRgGAQAAAFiYvaM7AAAAABANEhYAAACwPCQsAAAAYHlIWAAAAMDykLAAAACA5SFhAQAAAMtDwgIAAACWh4QFAAAALA8JCwAAAFgeEhaIW6tWrSKbzUbl5eVR2w4YMIDmz58fur179271XP4ZxI9zO4BE98knn9C4ceOoa9euajs4evRoaHu6ErrbzunTp9Xv2LhxI8UK/y7+nfy729Ptt9+uFogdJCwAEdTX16sg3jypAYh3Xq+XZs+eTRcuXKAXXniB/v3f/5369+/f0d2ypM8//1zFgPZOduDKOK/weQBx5fjx42S3t52fv/LKKxQIBMISlscff1z9H3tSkChOnjxJZ86cUd/3Bx54IHT/smXL6Kc//WmH9s2KCQvHAN7+W44g/eEPf+iwfnVWSFggKv7D3aVLF4pnbrc7apukpKSY9AWgI50/f1797N69e9j9TqdTLaDH5XJ1dBc6HRwSinPB485/+ctf6Pvf/z6lpaXRVVddRQ899BA1NjaGtX3jjTdo9OjRlJKSQj179qR7772XioqKwtrwnsRNN91Ehw8fpm9961sqUXnsscdCx57/5V/+hdauXUsDBw5Uj02dOlWtgyf9fvLJJ+maa65R67/77rvVkHNL//Vf/0UTJkxQx85TU1PpzjvvpP/93/8Na/PnP/9ZHRPn35GcnEyZmZn0ox/9iCoqKlp9D/gclmivveU5LNGOw/PrzcjIUP/nPSx+7bzw+/3qq6+q/3/66aeXrePpp58mh8NBX375ZZu/C6Aj8Hd84sSJ6v98WIi/x8HRw0jnsOjEjdZUVlaq35eenq6So9zcXHWf7mEr3u6uv/56FQN4ux4/fjzt3LkzrN0HH3wQiif8OzjufPHFF1HXH9yWW2oeJ/jcF36P2KRJk0IxIHiIuLVzWDgZvP/++6lPnz6q3yNHjqTXXnstrE3zWPqrX/2KBg0apHaovvGNb6hziyAypNMJgv9g88aWn59PBw8epJdeeokuXrxIr7/+unr8Zz/7GS1fvly142HgsrIy+uUvf6mSEv7D23xvixOD6dOnq8D0gx/8QG18Qb/5zW/I4/HQP/3TP6mE5Nlnn1Xr/Pa3v6025EcffZQKCgrUupcuXUobNmwIPZePlXPQmjZtGj3zzDNq5GbdunUqEHEfgskCB6W//vWvtGDBApWscELDGzb/5NfWMqhGe+1XgpMV7tuDDz5Is2bNou9+97vq/hEjRtC1115LixYtUu/FzTffHPY8vo+D2NVXX33FvxvALH//93+vvpucWP/zP/+z+iPZfPtuSRI3muMdGE4e9u3bR//wD/9AN9xwA73zzjtq+9fByQRvz/w7x4wZQ9XV1XTo0CE6cuQIfec731Ft3n//fRWneMeG2zc0NKi+3Xbbbard1z2Bnl8jv0ccT3injV8DC/5siX8/b/sc/xYvXqzixNatW1UCxIka70g19+abb1JNTY36TDimcSzlOMOxD6O9ERgQ11auXGnwx3jXXXeF3f+P//iP6v4//elPxunTpw2Hw2H87Gc/C2vz2WefGU6nM+z+iRMnquetX78+rO2pU6fU/RkZGUZlZWXo/ry8PHX/yJEjDa/XG7p/7ty5hsvlMhobG9Xtmpoao3v37sbChQvD1ltaWmqkp6eH3V
9fX3/Z69y0aZP6PXv37hW99qD+/fsbubm5odsffvihasM/g/hxbhdUVlam2vDvaYlfX3Z2tuH3+0P3HTlyRLV/9dVXL2sPYBXB7/7WrVvD7g9uT0GSuNFy29m2bZta17PPPhu6z+fzGRMmTNDaRjie3HnnnW22GTVqlNG7d2+joqIidB9v83a73Zg3b17oPv5d/Ds5hgVF2q5bxgl+j1rGieaxkpegF198UbV94403Qvd5PB4jJyfH6Natm1FdXR0WS6+66irjwoULobbbt29X9//+979v83V3ZjgklCB4j785HgFh7733Hr399tvqZFLeS+LDJ8GFRy94yPXDDz8Mey4PT/LoRmt4iJSHeIPGjh2rfvJITPPj33w/j8QED43wqAnvZcydOzesD3z4hNs27wMPPQfxoR1u981vflPd5j0nyWs3y7x58+js2bNh/ebRFe77PffcY9rvBYgVadxojrc9jgc8QhnE23pw24yGR254RPXEiROtPl5SUqIuxebRCz5MFcQjoDwCY+a2Hwn/Tn5vOMYF8UgJj9LU1tbSnj17wtrPmTOHevToEbrNh7YYj7BA63BIKEFwAGmOj4vyVTF8vJR/8k5FyzZBLYcfecg40gll/fr1C7sdTF769u3b6v18aIYFAw8fOmoNn38SxIea+Pj15s2bQycIBlVVVYleu1k4KGZlZakkZfLkySqwb9q0SQ2D87k5APGOt1lJ3GiOr0Li7aNbt25h9w8ZMkTrdz/xxBNqWxo8eLA6p+6OO+6gH/7whyohCa4/0vr4kM1///d/U11dnTq3JVa4T/xetbwaMXgIKdjnSLE0mLwEYyZcDglLgmp+ngf/MeXbfMIr7+W01DKoNB/haKm157d1/6XR10t9CJ7HwnshLTUfneE9uv3799MjjzxCo0aNUv3j53PQan7ZcSRXWvxKgl/v3/3d36lLQ//1X/+V/ud//keNuPBIE0AikMaN9sTnj/Dl19u3b1eXD//bv/2bqhmzfv36sEux25vf76dYiRYz4XJIWBJob4hP8griE7844PCJZ7xh8EbAj/MeS0fgUQ/Wu3dvmjJlSsR2vHexa9cuNcKyYsWK0P2RhoajvfavI1riw4eFnnvuOfr973+vgjqfqMsnFAMkAt5mrzRucCE63o75UEjzxIbrIeniQz18aJoXXg8nMXxyLScswUJ3ra2Pr5js1atXm6MrPJrR8oolPoTNh5qudOeH+8RXOHLsaT7Kwv0JPg5fD85hSRB8qXFzfLY847Po+cxzTlo4CWiZvfPtSJcLtyf+Q86HffjqBL5ksSW++qD5XkfLfr744otX9Nq/jmDtmUiXYvLwNC+89/e73/1OXVWFOhaQKL5O3Pibv/kb8vl86kq75qMXwW0zmpbr5qTnuuuuo6amJnWbDzfx6CtfMtx8+zx27JgakeHfHy0Z27t3b9h9fCViyxGWYNKjczk2/87S0lLasmVL6D5+D/g1c/+Dl5PDlUN0TRCnTp2iu+66Sx02OXDggKqdwIcsuA4Ae+qppygvL0+d1zFz5kx1ngU/hy81/PGPf6wuQTYTJyscvPg49C233KL+uPOIRGFhIb377rvqUsQ1a9aodrwnxZf4cWLD59NwAOK+Xulrv1J8aOzGG29UAYj3MHmPj4+n89J8lCX43uFwECQS/qN+pXFjxowZapvmyrn8XN6O+CTe1s5Baw2350uEuf4Lb3d8SfNvf/tbdblw0OrVq9VOSU5Ojqp9Erysmc+fa63GSnM8SsOXW/MJ8nw+2p/+9Cd13guPzDTHSREnbVyGgfvOFyTweXg8UtwSvx8vv/yyOhGY61jxCC/3mQ8X8w4Xzm1rBx19mRJ8PcFLET///HPje9/7npGammr06NHDWLx4sdHQ0BDW9ne/+50xfvx4o2vXrmoZOnSosWjRIuP48eOhNnyZ3rBhwy77PcFL8VavXq11iWTwUsJPPvnksvbTpk1TlzInJycbgwYNMubPn28cOnQo1K
a4uNiYNWuWugya282ePds4e/bsZZciSl77lVzWzPbv32+MHj1aXaLd2qWQJSUl6tLPwYMHX/aeAcTzZc2SuNHatsOXG//whz800tLS1HbM///000+1Lmt+6qmnjDFjxqgYkJKSon4nX0bNlwk39/777xu33XabasO/Z8aMGSoeNNfaZc1cjuDRRx81evXqZXTp0kXFpIKCgsviBHvllVeMgQMHqu28ecxoeVkzO3funLFgwQK1Xo4Zw4cPv+y1RoqlbV1uDZfY+J/2SHygY/CeBA/Z8iGVlnsHYD6+zJOHp/l8Gy6wBQAA5sA5LABfA5fv5uPefKgLAADMg3NYAK4Az2HCM7ly6XI+tv91r0gCAIC2IWEBuAJc2IprxfCJhbpXPgAAwJXDOSwAAABgeTiHBQAAACwPCQsAAABYXkKcw8KlkHkeFy7ME4t5ZADgcnx0uaamhrKzsy+bAM6qEDsA4iduJETCwgGn5WzBANAxioqK6JprrqF4gNgBED9xw7SEhed34dLJPLcCl0jnKynGjBkTsf3WrVtV4S0u48xTdHMp5GjzQQQFSx5fs3IZ2ZOTo7a3RZ/wN0zXQsHeomAnrfqGy+fUicRdEnkq96+ry3n9866TamXnaKeU6b9Gsuu/eQGHbG+4LlP/q558Uf8L4mzQb2sTnt/uqrw0b4qOQFLrM7+2xlkQPsFbNP7ycq12PvLSPnrva5Ugj2XcYMG+njkygNK6Rd/Oy/x1JFHu148dFwKRZ0lv6aI/8sR+LVUJ1qvW7dOfhbnCq9+Pco9sdufyJv325Q2X5v3ScaFGvy3zVEX/mxLkrNSPM+4K/RiWXCGMu+X6s06nnKvXbusouSDqh6/0XLvGDVMSFp57ZcmSJWoq8LFjx6p5FHjyO55Zs7U5GPjy0Llz51J+fj797d/+Lb355puqtsWRI0fC5m2JJDiUy8mKGQmLw21OwmJP0f8j43Cbl7A4VNX59m/LnE6HOQmLU5awOFz6X3VnkiBh8QkSloDwvRMkZQHB++y0u0T9sNk0v3v//+Vd6aGVWMeN5n3lZCUtNfp23ihIQKTtmwKCtn79z9vjl4X5Bp/gj65XPy65PLLvXZJTv73T7tZu6/DrJyDM7tFvb2/Uf+8cbpt5cTdJP2FxOvTbOoSxg3RihyBumHKg+fnnn6eFCxeqacF5EisOQDzz7YYNG1pt/4tf/EJNXPfII4/QDTfcQE8++aSaII8nwwOAzgFxAwBimrB4PB41U+WUKVO++iV2u7rNM+m2hu9v3p7xnlWk9jzFeHV1ddgCAPErFnGDIXYAxC+7GZPB8dwqffr0Cbufb/Nx6dbw/ZL2PATMU4gHF5w0BxDfYhE3GGIHQPyKj2sPW8jLy6OqqqrQwmcXAwBEg9gBEL/a/aTbXr16kcPhoHPnws8O5tuZmZmtPofvl7R3u91qAYDEEIu4wRA7AOJXu4+wuFwuGj16NO3atSusOBPfzsnJafU5fH/z9mznzp0R2wNAYkHcAIAOuayZL03Mzc2lW2+9VdVQ4MsT6+rq1Nn/bN68eXT11Ver48nsoYceookTJ9Jzzz1Hd955J23evJkOHTpEv/rVr8zoHgBYEOIGAMQ8YZkzZw6VlZXRihUr1Alwo0aNoh07doROkCssLAwrwTtu3DhVQ2HZsmX02GOPqQJQ27Zt066lEORssJHdiH4tt7NWVieiZpB+rY3kvjX6K67QL2DUmO2TFYMrFNQd0a8bRIEk4XvXT/+6fY2P7qu2gvIu0to7Vdfqr9wm6Ei3Yv16B6ypu35RrtST+t874+oMUT9smlfS2Aw7kX6tO8vEDYA2mTVjg1VmgrBJiod17GmvNoML+cc5vjSRz/gfuOJnWoXjpAlLYx9zEpZ6QcJCAZtpCUu3IsO0RIH/hsVbwuJJ0++ITbD1SBMWSSCRJCw2v2yTN744qdXOZ3jpw6a31MmsaWlpFE+x4+L/DdQqHHdeWOm2TFA4rkJQkfaCX78KbKVfVt
n1gqDSbblXv22ZR1YBuaxRsO4G/eS+olq/LWuSVLq9IKioLal0WybbZruUCSrdlppY6fbLs9HbGF7aTdu14kZcXiUEAAAAnQsSFgAAALA8JCwAAABgeUhYAAAAwPKQsAAAAIDlIWEBAAAAy0PCAgAAAJaHhAUAAAAsDwkLAAAAWB4SFgAAAOiccwl1FE+Gj+wp0efc8fSRlTm+YVD08sJBxVXp2m3Tetdqt236c3cyqyR+bV/9EtF22ZRGIgFBuf2A/hRFij9ZMP2AoN6+ZIqAgFM2n0CXc/r9aOyjX369y8mLon7Yuut9p+0BD9F50aohBvySYMDtBZPc+AX7vAHJ3BsWmTrn0hMEfy9sJk1bIh1asEnaWuNz0YERFgAAALA8JCwAAABgeUhYAAAAwPKQsAAAAIDlIWEBAAAAy0PCAgAAAJaHhAUAAAA6X8KSn59P3/jGNyg1NZV69+5NM2fOpOPHj7f5nI0bN5LNZgtbkpOT27trAGBRiBsAEPOEZc+ePbRo0SI6ePAg7dy5k7xeL02dOpXq6urafF5aWhqVlJSEljNnzrR31wDAohA3ACDmlW537Nhx2V4Q7zEdPnyYvvWtb0V8Hu8dZWZmtnd3ACAOIG4AQIeX5q+qqlI/e/bs2Wa72tpa6t+/PwUCAbrlllvo6aefpmHDhrXatqmpSS1B1dXV6qfN5VdLNH16X+qTria//ts0oId+2fNjRwdot7V1k00nQKTf3u7RL81sM7PcfkpAu62REv1zbs6hMWVDkNOpv25Pvf4cAXXJss3N7tF/85IvCsqpp6eI+uGo0PxOB2SfSazjRluxI5FJyuezgGGNcvsBcbTRY5OU2ldP0G9q2AXTethtph0LMQRx14ijMv6mnnTLQeThhx+m2267jW666aaI7YYMGUIbNmyg7du30xtvvKGeN27cOCouLo54vDs9PT209O3b18RXAQCxZFbcYIgdAPHL1ISFj0kfO3aMNm/e3Ga7nJwcmjdvHo0aNYomTpxIb7/9NmVkZNDLL7/cavu8vDy1BxZcioqKTHoFABBrZsUNhtgBEL9MOyS0ePFi+s///E/au3cvXXPNNaLnJiUl0c0330wFBQWtPu52u9UCAInFzLjBEDsA4le7j7AYhqGCzjvvvEMffPABXXvtteJ1+P1++uyzzygrK6u9uwcAFoS4AQAxH2Hh4dw333xTHVfmmgqlpaXqfj5enJJy6WQ/Hsa9+uqr1fFk9sQTT9A3v/lNuu6666iyspJWr16tLk984IEH2rt7AGBBiBsAEPOEZd26dern7bffHnb/q6++SvPnz1f/LywsJLv9q8Gdixcv0sKFC1WQ6tGjB40ePZr2799PN954Y3t3DwAsCHEDAGKesPDQbjS7d+8Ou/3CCy+oBQA6J8QNAIgGcwkBAACA5SFhAQAAAMtDwgIAAACWZ3pp/lhypXjJ0SV6TeK+qZWi9doFpZw/+uw6/RW7BaXoHcJy0j79EsoBSaV2p6wfdkFJ/G7dGrXbZnRre1K8y9qn1JIZSuvStNueKeolWrenu3597UCSYHoFr/73DmJHUhLfI6i9Li2f7zepNL9ZpfaZrLq8LIbZBLFXVBJfMm2J8C91wKH/jhgOwbhFIpfmBwAAAGgPSFgAAADA8pCwAAAAgOUhYQEAAADLQ8ICAAAAloeEBQAAACwPCQsAAABYHhIWAAAAsDwkLAAAAGB5SFgAAADA8hKqNL8RsFEgEL108PHy3qL11hSnmVO6PklQIl3jdYVx6a/b7vJrt03u4hF1IyNVv4T+oLRy7bZDu5WI+pGdpD8dQ13Ard32I8dA7balqakkYQu4tNv6BaX57XWNsn449cKELYCS/y35BUXjJW0lZfy9hizMewU14/2CfV6/cIoAs4iry9sFpfkF8V9Sbj/glHXaEJXmF6zbjtL8AAAAAG1CwgIAAACdL2FZtWoV2Wy2sGXo0KFtPmfr1q2qTXJyMg0fPpzee++99u4WAFgY4g
YAdMgIy7Bhw6ikpCS07Nu3L2Lb/fv309y5c+n++++nTz/9lGbOnKmWY8eOmdE1ALAoxA0AiHnC4nQ6KTMzM7T06tUrYttf/OIXdMcdd9AjjzxCN9xwAz355JN0yy230Jo1a8zoGgBYFOIGAMQ8YTlx4gRlZ2fTwIED6b777qPCwsKIbQ8cOEBTpkwJu2/atGnq/kiampqouro6bAGA+GZ23GCIHQDxq90TlrFjx9LGjRtpx44dtG7dOjp16hRNmDCBampqWm1fWlpKffr0CbuPb/P9keTn51N6enpo6du3b3u/DACIoVjEDYbYARC/2j1hmT59Os2ePZtGjBih9nj4RLjKykp666232u135OXlUVVVVWgpKipqt3UDQOzFIm4wxA6A+GV64bju3bvT4MGDqaCgoNXH+Vj1uXPnwu7j23x/JG63Wy0AkJjMiBsMsQMgfpleh6W2tpZOnjxJWVlZrT6ek5NDu3btCrtv586d6n4A6JwQNwDA9IRl6dKltGfPHjp9+rS69HDWrFnkcDjUJYhs3rx5alg26KGHHlLHrZ977jn6y1/+ouoxHDp0iBYvXtzeXQMAi0LcAICYHxIqLi5WQaaiooIyMjJo/PjxdPDgQfV/xmf+2+1f5Unjxo2jN998k5YtW0aPPfYYXX/99bRt2za66aabxL/bMC4t0dScSRet1+7VbxtIFqzY5zBljgrFod/eKZhLqEfXBlE3hnQPH7ZvS07aSe22tyRHvoKkNRl2n3bbIr/+IYPCpqu029oFc5IwydQrDo/w+yEQqKzSa2fI5pmyStwwU0DwIUrm/PGI5vuRzf8SELT3BRymzH90qb0589bY7bI5r2yC7TYgiLuSmG4I5xIKCP6yG07B59JsG0yIhGXz5s1tPr579+7L7uOT7XgBgM4JcQMAosFcQgAAAGB5SFgAAADA8pCwAAAAgOUhYQEAAADLQ8ICAAAAloeEBQAAACwPCQsAAABYHhIWAAAAsDwkLAAAAGB5ps/WHEuuT7uRwx29Nn5Skmy9AcHkrka9oIyzoB8Bl6z0ut9mTqn2rkmy8uuZ7mrttgNd5/XbCr+53ezdtNvWG7XabR02/TLffr9s/8ApmBLCFhCU+U6WbQABj15HAob+9Afxyq8z90czHsE+oaTcvqSMv1ewXtVeUm5fUMZfWmrfrNL8UpLS/JIhAMHbLCq1zwKCUv6GQ/A+2zr2M8EICwAAAFgeEhYAAACwPCQsAAAAYHlIWAAAAMDykLAAAACA5SFhAQAAAMtDwgIAAACdL2EZMGAA2Wy2y5ZFixa12n7jxo2XtU1Ojl5LBQASC2IHAMS0cNwnn3xCfr8/dPvYsWP0ne98h2bPnh3xOWlpaXT8+PHQbQ48ANC5IHYAQEwTloyMjLDbP//5z2nQoEE0ceLEiM/hIJOZmdneXQGAOILYAQAdVprf4/HQG2+8QUuWLGlzz6e2tpb69+9PgUCAbrnlFnr66adp2LBhEds3NTWpJai6+lL5d65ArVOF2l0pK69t2G2mlESWlPz3pcj2HCVVrb3J+l+DWq9L1I96v377esEbUm/UifpBgUbtpmX+FO22FR79kv8+j6xEuktQ6d7u1f9O2xpk0yvYXXql/O1ctl7/bbZU7ND11fiPnoChf9Q9IDhC7xe09QrruktK+fsFgUZSxt9MNjPbC8r4G6aW5ieTyvh37Gmvpv72bdu2UWVlJc2fPz9imyFDhtCGDRto+/btKkBx4Bk3bhwVFxdHfE5+fj6lp6eHlr59+5r0CgCgIyB2AEBME5Zf//rXNH36dMrOzo7YJicnh+bNm0ejRo1SQ79vv/22Ghp++eWXIz4nLy+PqqqqQktRUZFJrwAAOgJiBwDE7JDQmTNn6P3331dBRCIpKYluvvlmKigoiNjG7XarBQASD2IHAMR0hOXVV1+l3r1705133il6Hl8l8Nlnn1FWVpZZXQMAC0PsAICYJSx8LJmDTm5uLjmd4YM4PITLw7JBTz
zxBP3hD3+gv/71r3TkyBH6wQ9+oPawHnjgATO6BgAWhtgBADE9JMTDuYWFhfSjH/3ossf4frv9qzzp4sWLtHDhQiotLaUePXrQ6NGjaf/+/XTjjTea0TUAsDDEDgCIacIydepUMvgSx1bs3r077PYLL7ygFgAAxA4AiARzCQEAAIDlIWEBAAAAy0PCAgAAAJaHhAUAAAA691xCsZZ22k/OJOlsH+08l5DetCuKz62/Xns32QwYtoB+Ltro0J/vp8ytP3cO+yJZf2K6dGeDdluPZCIOIkq2ebXbft50tXbbgppe2m0DdYIvB3/mkil/JNNjNbvSRoctWa/Qmo3nlWmnuYSsSjBlk+IXzEQj+U5L5vuRzuETEMwP5AsI+iGZ4EzYXvixiNhsgrVLXqJovcK55Ozm/H2T9qO9YYQFAAAALA8JCwAAAFgeEhYAAACwPCQsAAAAYHlIWAAAAMDykLAAAACA5SFhAQAAAMtDwgIAAACWh4QFAAAALA8JCwAAAFheQpXmT77oI6fTF7VdwCkscy+ooCxZtz9ZP190Nsn6bPcKclFBGf8mo4uoH38J9NFuW92UrN32825Zon6kOPRL859rTNVue7q8p3ZbZ7VsOgGHxzDlO0qSUtwsSXPqhoBstZ2BX7BPGDDMaeuX1GmX9llQi15amt8Qttder7S9Sf2QlPGXdsGQfOSi6QRQmh8AAACgfROWvXv30owZMyg7O5tsNhtt27Yt7HHDMGjFihWUlZVFKSkpNGXKFDpx4kTU9a5du5YGDBhAycnJNHbsWPr444+lXQMAi0LcAICYJyx1dXU0cuRIFSha8+yzz9JLL71E69evp48++oi6du1K06ZNo8bGyNO4btmyhZYsWUIrV66kI0eOqPXzc86fPy/tHgBYEOIGAMQ8YZk+fTo99dRTNGvWrMse472kF198kZYtW0Z33303jRgxgl5//XU6e/bsZXtUzT3//PO0cOFCWrBgAd14440qaHXp0oU2bNggf0UAYDmIGwBgqXNYTp06RaWlpWo4Nyg9PV0N1R44cKDV53g8Hjp8+HDYc+x2u7od6TlNTU1UXV0dtgBAfIpV3GCIHQDxq10TFg46rE+f8CtD+HbwsZbKy8vJ7/eLnpOfn68CWnDp27dvu70GAIitWMUNhtgBEL/i8iqhvLw8qqqqCi1FRUUd3SUAiAOIHQDxq10TlszMTPXz3LlzYffz7eBjLfXq1YscDofoOW63m9LS0sIWAIhPsYobDLEDIH61a8Jy7bXXqmCxa9eu0H18jJjP+s/JyWn1OS6Xi0aPHh32nEAgoG5Heg4AJA7EDQAwpdJtbW0tFRQUhJ0wd/ToUerZsyf169ePHn74YXU1wPXXX68C0fLly1XthZkzZ4aeM3nyZHW1wOLFi9VtvjQxNzeXbr31VhozZoy6YoAvg+Sz/wEg/iFuAEDME5ZDhw7RpEmTQrc5aDAOHBs3bqSf/OQnKmj8+Mc/psrKSho/fjzt2LFDFXYKOnnypDppLmjOnDlUVlamCkfxCXOjRo1Sz2l5Ql00SVVN5NSofh5IThKt1+bXrzluOARlrRv02zoapWXdJW31yy07PLJBuab6FO22RZWaJeCJqLhrD1E/7En6n2HAI3ivq/U3oeRq4fQKgtL8ktLdRpLsu2Tvpjcdgy3gIPpqs46buCEhnX1AVEJfMODtl5TEF9Vel5fQN2u9ZpX9l5baN8xqLJ0jQEA0VUccsRlcBCHO8fAxn/E/6eafktPhjp+ExaXf1pci+yPj7abfvilNfwNu6iHb2Jt66H+9vN392m1tXaPPGWW5hOW87DPsUqL/3nU7q/9+pHxZI+qHvaZBq50v0ETvn16jTmaNl3NDgrHj4v8NpLTU6NvjSW+taP1FPv334UuffhJe5tOf7+qir6t2W9Xeqz9fWI1Pf/6vKo9+W7Vur377Wo/+zk5tY/S/Ec01NOiv21+r/7fFIZhbzHVRtqPY5bwgdnypHzu6nLwo6o
f/+FejqpH4DC/tpu1acSMurxICAACAzgUJCwAAAFgeEhYAAACwPCQsAAAAYHlIWAAAAMDykLAAAACA5SFhAQAAAMtDwgIAAACWh4QFAAAALA8JCwAAACTeXEKJwFHVKHyCOXNr2AVzutibhHMJNel/tM4G/XUn1clyXHel/nvn7abfZ1+K7KsbcJkzx4ddMGdTkqyqu2g+KMncIf5ustLkZNf7zAP+ThlOLM8vmM9IPP+RYF4eyXrNnB9IOpeQaKKugH5bm6itfhesNKdRe8MICwAAAFgeEhYAAACwPCQsAAAAYHlIWAAAAMDykLAAAACA5SFhAQAAgMRLWPbu3UszZsyg7OxsstlstG3bttBjXq+XHn30URo+fDh17dpVtZk3bx6dPXu2zXWuWrVKrav5MnTo0Ct7RQBgOYgbABDzhKWuro5GjhxJa9euveyx+vp6OnLkCC1fvlz9fPvtt+n48eN01113RV3vsGHDqKSkJLTs27dP2jUAsCjEDQD4usSVnqZPn66W1qSnp9POnTvD7luzZg2NGTOGCgsLqV+/fpE74nRSZmamtDsAEAcQNwDA8uewVFVVqaHa7t27t9nuxIkTaih44MCBdN9996lAFUlTUxNVV1eHLQCQOMyIGwyxAyB+mVpLu7GxUR2bnjt3LqWlpUVsN3bsWNq4cSMNGTJEDes+/vjjNGHCBDp27BilpqZe1j4/P1+1acle7yG7Thl9m7A0s7Qssiabx6fftlFWmt/eoL9uZ52gNH+N7CvjT9Zfty9FUBLcLfsM/S5BCXHZW63Nrv+RKEn1+l88b1f99y7glJXmd9n13jufL2DpuNFW7IArIym3HyCbaaX5/QFBPwQl8VV7v6Q0P5lSbt/u12+r1u03Z1oP0owFcTfCwifSff/73yfDMGjdunVttuWh4tmzZ9OIESNo2rRp9N5771FlZSW99dZbrbbPy8tTe2DBpaioyKRXAQCxZGbcYIgdAPHLaWbQOXPmDH3wwQdt7iW1hoeBBw8eTAUFBa0+7na71QIAicPsuMEQOwDil92soMPHlt9//3266qqrxOuora2lkydPUlZWVnt3DwAsCHEDANo9YeGgcPToUbWwU6dOqf/zyW4cdL73ve/RoUOH6De/+Q35/X4qLS1Vi8fjCa1j8uTJ6iqAoKVLl9KePXvo9OnTtH//fpo1axY5HA51DBsA4h/iBgDE/JAQB5VJkyaFbi9ZskT9zM3NVYWc/uM//kPdHjVqVNjzPvzwQ7r99tvV/3kvqLy8PPRYcXGxCjIVFRWUkZFB48ePp4MHD6r/A0D8Q9wAgJgnLBw8+IS4SNp6LIj3iJrbvHmztBsAEEcQNwDg68JcQgAAAGB5SFgAAADA8pCwAAAAgOUhYQEAAIDOXZo/1mweL9ns0XMwI0VYOMrjFXRCULpY40TDIHuToA+8akkJ5Tr9vNXhlNWtN5IcprQNuGVf3YBLUPZf0NZwmleq2ttV//3wJ+mv19koLKGv+52WTnkBnYqk1D4zRKX5BW39sv10Q1Ca3+YTtJX8WfGZNw2ILSCoze83aZ4aTRhhAQAAAMtDwgIAAACWh4QFAAAALA8JCwAAAFgeEhYAAACwPCQsAAAAYHlIWAAAAMDykLAAAACA5SFhAQAAAMtDwgIAAACWl1Cl+Rv79SSnMzlqO/eXVbIVu/TrntsamsgUfr+oua1BVsrfLDpTJYQIyv7bk5ymfYaGS3/dAZd+nz09hFNCCKZucFcJymsL6ZbuFpX4hk5Xbl9amt8XsJvSNiAszU9e/fZ2r82kttpNL7X362+LNp9guw2gND8AAABAm5CwAAAAQOIlLHv37qUZM2ZQdnY22Ww22rZtW9jj8+fPV/c3X+64446o6127di0NGDCAkpOTaezYsfTxxx9LuwYAFoW4AQAxT1jq6upo5MiRKlBEwoGmpKQktGzatKnNdW7ZsoWWLFlCK1eupCNHjqj1T5s2jc6fPy/tHgBYEOIGAMT8pN
vp06erpS1ut5syMzO11/n888/TwoULacGCBer2+vXr6d1336UNGzbQT3/6U2kXAcBiEDcAwJLnsOzevZt69+5NQ4YMoQcffJAqKioitvV4PHT48GGaMmXKV52y29XtAwcOtPqcpqYmqq6uDlsAIL6ZHTcYYgdA/Gr3hIWHdV9//XXatWsXPfPMM7Rnzx61Z+WPcFlueXm5eqxPnz5h9/Pt0tLSVp+Tn59P6enpoaVv377t/TIAIIZiETcYYgdA/Gr3Oiz33ntv6P/Dhw+nESNG0KBBg9Te0+TJk9vld+Tl5alj10G8l4TAAxC/YhE3GGIHQPwy/bLmgQMHUq9evaigoKDVx/kxh8NB586dC7ufb0c6ns3HutPS0sIWAEgcZsQNhtgBEL9MT1iKi4vVseisrKxWH3e5XDR69Gg1FBwUCATU7ZycHLO7BwAWhLgBAF87YamtraWjR4+qhZ06dUr9v7CwUD32yCOP0MGDB+n06dMqeNx999103XXXqcsNg3iId82aNaHbPET7yiuv0GuvvUZffPGFOuGOL4MMnv0PAPENcQMAYn4Oy6FDh2jSpEmh28Hjwbm5ubRu3Tr685//rAJIZWWlKhI1depUevLJJ9VQbNDJkyfVSXNBc+bMobKyMlqxYoU6YW7UqFG0Y8eOy06oi8ZVUU9OR/Q5d/w9uojW67hYr93WSHZpt7XVNeh3wi+bw8HwCeYe8nr025o5X4xdMNeIQ38OH2ZzCT4XwbxD1L2bdlOHR7BeIkopE3zmgrfOlyx772r7Juut1xufccPMPTy7Tf8zdJCkrf526BD0wUzSuYT8knmKRHMJyfpBgvY2k+YHskvm+yFuL2jr1f9+2CR/V0xgMwzBDGsWxSfO8Rn/377pEXI6ok8wF+ii/8dLmrCQw26NhKVJkIQgYQknSFgCgoTF2zOFRCRvtYkJiyddr73P20iHf7uMqqqq4ubckGDsuPh/AyktNfq2e8pbK1p/kV//+3HW20O7bZlP//0t9+n3gVV49Ntf8Ojv/F1oku0oVjXpJcqspkG/bUOdLP4HavXjgbNKf9tyVelvtO5KWdxNqdBv36VEf8JeV1HkUgOt8Z0ujN7G8NJu2q4VNzCXEAAAAFgeEhYAAACwPCQsAAAAYHlIWAAAAMDykLAAAACA5SFhAQAAAMtDwgIAAACWh4QFAAAALA8JCwAAACReaX4rM5IcZGhUQLX5ZFVjJaX8HRfqtNsakhLw1bIKmxTQL6FsePXrOBs+n7Af5lTGtQkqCjOjUb+aoz0tVb8fHv33w3Ve/7vBAi79zdNzlX6lTylvil5FTr9DWPI8Dpn5EiVl/K0iICixbEhL8wvK7fv9gtL8PlnssPkE5fYFletFbYVh1y4o5S/6eyisuN7eMMICAAAAloeEBQAAACwPCQsAAABYHhIWAAAAsDwkLAAAAGB5SFgAAADA8pCwAAAAQOIlLHv37qUZM2ZQdnY22Ww22rZtW9jjfF9ry+rVqyOuc9WqVZe1Hzp06JW9IgCwHMQNAIh5wlJXV0cjR46ktWvXtvp4SUlJ2LJhwwYVSO6555421zts2LCw5+3bt0/aNQCwKMQNAIh5pdvp06erJZLMzMyw29u3b6dJkybRwIED2+6I03nZcwEgMSBuAIClS/OfO3eO3n33XXrttdeitj1x4oQaLk5OTqacnBzKz8+nfv36tdq2qalJLUHV1dXqZ2NGCjmTopcod1/QL9PO7PUe/cbO6FMDXAmbyyVqbzQ06q9bYzqD0Hqbve9a7SWl+QXTCRDJ3g+74P0zvF79FZdd0G5qc7tJJPsqU0pxV/eXbfZJDXrrtnkNS8eNtmKHWRwUiLsy/pJy+wFBuX3Jepk/oN/e7xesW9JWWJrfJghhZrVldsG2aJeU5jfMmWrFEifdcsBJTU2l7373u222Gzt2LG3cuJF27NhB69ato1OnTtGECROopqam1fYclNLT00NL3759TXoFABBrZsUNhtgBEL9MTV
j4OPR9992n9n7awkPFs2fPphEjRtC0adPovffeo8rKSnrrrbdabZ+Xl0dVVVWhpaioyKRXAACxZlbcYIgdAPHLtENCf/zjH+n48eO0ZcsW8XO7d+9OgwcPpoKCglYfd7vdagGAxGJm3GCIHQDxy7QRll//+tc0evRodWWAVG1tLZ08eZKysrJM6RsAWBPiBgC0W8LCQeHo0aNqYXzcmP9fWFgYdiLb1q1b6YEHHmh1HZMnT6Y1a9aEbi9dupT27NlDp0+fpv3799OsWbPI4XDQ3Llzpd0DAAtC3ACAmB8SOnTokLrcMGjJkiXqZ25urjoBjm3evJkMw4gYOHgvqLy8PHS7uLhYta2oqKCMjAwaP348HTx4UP0fAOIf4gYAfF02gyNEnOM9Mz7jP2fq4x1+WbNNcomYgK2mXtTeqDLncs1Afb0lLmu2JQkva06J/r0IcSWRGaSXNfsFlzX7uur3uXKQ25TLmv3eRjr81jJ1MmtaWhrFU+y4+H8DKS01+oBzoa9WtP4iXxfttl/6emi3LfPpv7/nvbLPosyTqt22vKmrdtuLTfrvBbtQn6LdtqZOf/v21gpjR43+fn1Sjf4l0K4q/bbui7I/013K9GNpSmmDdltHiX4ZB+Yr/jJ6G8NLu2m7VtzAXEIAAABgeUhYAAAAwPKQsAAAAIDlIWEBAACAzj2XUKzZvQGyG9FPevV2k51UaU/Rf5tcpZHLgrdkaxTMUSSZ34bX3SNdu23gXJl2W8Mvm9TCLjjR1PDrz2lkc8hybX9tHZnB0U3/pEOjZ3fRum1e/ffaL3if7T5RN4h0z/eL+9P3O5bDpDfQLlyvtL0Z8w4xQ9LerLaqPZnUD/2m4qmjDEpIGGEBAAAAy0PCAgAAAJaHhAUAAAAsDwkLAAAAWB4SFgAAALA8JCwAAABgeUhYAAAAwPKQsAAAAIDlIWEBAAAAy0uISreGcamsn8/XZMr67X79soF2v34fbAFBpVtJW9Vev88BQ3/dAUNWcdcuqPwY/Bx12ITVKgOGtLyrHkPw3pHgu6HWLSgq7PPp73v4PbJqxeTV+1z83kbx59jRgn2trtUrJVrjk5UcrRO0r/fpfy4Nfv3vc5OwSrbHo9/e69H//vsaZX9u/A363+lAvf56Aw3CsrGN+hW4/Y12U8KB3yPbpnyCKtk+/6XtVocRkMUwn8bfCx95teNGQiQsNTWXyuF/9Mefd3RXoCX9bUFGFoPNU21S2zjfHtPT9aeGsELs6H/L6Y7uCkCnVqMRN2xGPO0ORRAIBOjs2bOUmppKNttXe97V1dXUt29fKioqorS0NEo0if76OsNrTKTXx6GEg052djbZ7fFxtBmxA68vXlUnyGuUxI2EGGHhF3nNNddEfJw/zHj+QKNJ9NfXGV5jory+eBlZCULswOuLd2kJ8Bp140Z87AYBAABAp4aEBQAAACwvoRMWt9tNK1euVD8TUaK/vs7wGhP99cWrRP9c8Prin7sTvMaEPOkWAAAAEltCj7AAAABAYkDCAgAAAJaHhAUAAAAsDwkLAAAAWB4SFgAAALC8hE5Y1q5dSwMGDKDk5GQaO3Ysffzxx5QIVq1apcqIN1+GDh1K8Wzv3r00Y8YMVZ6ZX8+2bdvCHueL2VasWEFZWVmUkpJCU6ZMoRMnTlCivL758+df9pnecccdHdbfzixR40Yixg7EjfmdKm4kbMKyZcsWWrJkibpO/ciRIzRy5EiaNm0anT9/nhLBsGHDqKSkJLTs27eP4lldXZ36jPiPRWueffZZeumll2j9+vX00UcfUdeuXdXn2dho1uyKsX19jANN889006ZNMe0jJH7cSLTYgbhBnStuGAlqzJgxxqJFi0K3/X6/kZ2dbeTn5xvxbuXKlcbIkSONRMVfy3feeSd0OxAIGJmZmcbq1atD91VWVhput9vYtGmTEe+vj+Xm5hp33313h/UJEj9uJHrsQNxIfAk5wuLxeOjw4cNq+K/5JGd8+8CBA5QIeFiThw
kHDhxI9913HxUWFlKiOnXqFJWWloZ9njxZFg/XJ8rnyXbv3k29e/emIUOG0IMPPkgVFRUd3aVOpTPEjc4UOxA3Ek9CJizl5eXk9/upT58+Yffzbf4Cxzve4DZu3Eg7duygdevWqQ1zwoQJaoruRBT8zBL18wwO677++uu0a9cueuaZZ2jPnj00ffp09T2G2Ej0uNHZYgfiRuJxdnQHQI6/kEEjRoxQQah///701ltv0f3339+hfYMrc++994b+P3z4cPW5Dho0SO09TZ48uUP7BokDsSOx3NvJ4kZCjrD06tWLHA4HnTt3Lux+vp2ZmUmJpnv37jR48GAqKCigRBT8zDrL58l4uJ6/x4n6mVpRZ4sbiR47EDcST0ImLC6Xi0aPHq2GyYICgYC6nZOTQ4mmtraWTp48qS7dS0TXXnutCjDNP8/q6mp11n8ifp6suLhYHYtO1M/Uijpb3Ej02IG4kXgS9pAQX5qYm5tLt956K40ZM4ZefPFFdYnYggULKN4tXbpUXZvPQ7lnz55Vl2DynuHcuXMpngNn870CPrZ+9OhR6tmzJ/Xr148efvhheuqpp+j6669XgWj58uXqxMGZM2dSvL8+Xh5//HG65557VIDlPyA/+clP6LrrrlOXYELsJHLcSMTYgbjxeOeKG0YC++Uvf2n069fPcLlc6nLFgwcPGolgzpw5RlZWlnpdV199tbpdUFBgxLMPP/xQXbbXcuHL9oKXKC5fvtzo06ePuixx8uTJxvHjx41EeH319fXG1KlTjYyMDCMpKcno37+/sXDhQqO0tLSju90pJWrcSMTYgbgxtVPFDRv/09FJEwAAAECnO4cFAAAAEgsSFgAAALA8JCwAAABgeUhYAAAAwPKQsAAAAIDlIWEBAAAAy0PCAgAAAJaHhAUAAAAsDwkLAAAAWB4SFgAAALA8JCwAAABAVvf/AFzg6Qh9JoIaAAAAAElFTkSuQmCC",
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.subplot(1, 2, 1)\n",
"plt.title(\"permeability\")\n",
"plt.imshow(k_train[0])\n",
"plt.subplot(1, 2, 2)\n",
"plt.title(\"field solution\")\n",
"plt.imshow(u_train[0])\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "89a77ff1",
"metadata": {},
"source": [
"We now define the problem class for learning the Neural Operator. Since this task is essentially a supervised learning problem—where the goal is to learn a mapping from input functions to output solutions—we will use the `SupervisedProblem` class provided by **PINA**."
]
},
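{
"cell_type": "markdown",
"id": "a3f1c0d2",
"metadata": {},
"source": [
"In operator-learning terms, the task can be sketched as follows. The notation below ($\\mathcal{G}$, $\\mathcal{G}_\\theta$, $N$, $\\ell$) is introduced here for illustration and is not taken from the **PINA** API. We look for a parametric model $\\mathcal{G}_\\theta$ that approximates the solution operator\n",
"\n",
"$$\n",
"\\mathcal{G}: k \\mapsto u, \\qquad -\\nabla\\cdot(k\\nabla u) = f \\ \\text{in } D,\n",
"$$\n",
"\n",
"by minimizing an empirical risk over the $N$ training pairs $(k_i, u_i)$:\n",
"\n",
"$$\n",
"\\min_{\\theta} \\ \\frac{1}{N}\\sum_{i=1}^{N} \\ell\\big(\\mathcal{G}_\\theta(k_i),\\, u_i\\big),\n",
"$$\n",
"\n",
"where $\\ell$ measures the discrepancy between prediction and target on the grid, for example a mean squared error. Which loss the solver uses by default is not stated in this tutorial, so the choice of $\\ell$ here is an assumption."
]
},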
{
"cell_type": "code",
"execution_count": 6,
"id": "8b27d283",
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-19T13:35:29.136572Z",
"start_time": "2024-09-19T13:35:29.134124Z"
}
},
"outputs": [],
"source": [
"# make problem\n",
"problem = SupervisedProblem(\n",
" input_=k_train.unsqueeze(-1), output_=u_train.unsqueeze(-1)\n",
")"
]
},
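{
"cell_type": "markdown",
"id": "b47e2a91",
"metadata": {},
"source": [
"The `unsqueeze(-1)` calls above append a trailing axis of size one, so that each grid point carries a single scalar feature. A minimal sketch with synthetic data follows; the sizes `100` and `20` and the names `k_demo` and `k_demo_in` are placeholders assumed for illustration, not values read from `Data_Darcy.mat`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c5d80f36",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# synthetic stand-in for k_train: 100 samples on a 20x20 grid (assumed sizes)\n",
"k_demo = torch.rand(100, 20, 20)\n",
"\n",
"# append a trailing feature axis: [samples, x, y] -> [samples, x, y, 1]\n",
"k_demo_in = k_demo.unsqueeze(-1)\n",
"print(k_demo.shape, \"->\", k_demo_in.shape)"
]
},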
{
"cell_type": "markdown",
"id": "1096cc20",
"metadata": {},
"source": [
"## Solving the Problem with a Feedforward Neural Network\n",
"\n",
"We begin by solving the Darcy flow problem using a standard Feedforward Neural Network (FNN). Since we are approaching this task with supervised learning, we will use the `SupervisedSolver` provided by **PINA** to train the model."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "e34f18b0",
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-19T13:35:31.245429Z",
"start_time": "2024-09-19T13:35:29.154937Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: True (mps), used: False\n",
"TPU available: False, using: 0 TPU cores\n",
"HPU available: False, using: 0 HPUs\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0b77243fe0274dada29b6bb5a15c47e8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"`Trainer.fit` stopped: `max_epochs=10` reached.\n"
]
}
],
"source": [
"# make model\n",
"model = FeedForward(input_dimensions=1, output_dimensions=1)\n",
"\n",
"\n",
"# make solver\n",
"solver = SupervisedSolver(problem=problem, model=model, use_lt=False)\n",
"\n",
"# make the trainer and train\n",
"trainer = Trainer(\n",
" solver=solver,\n",
" max_epochs=10,\n",
" accelerator=\"cpu\",\n",
" enable_model_summary=False,\n",
" batch_size=10,\n",
" train_size=1.0,\n",
" val_size=0.0,\n",
" test_size=0.0,\n",
")\n",
"trainer.train()"
]
},
{
"cell_type": "markdown",
"id": "7b2c35be",
"metadata": {},
"source": [
"The final loss is relatively high, which suggests that the model does not capture the solution accurately. To quantify this, we compute the error on both the training and the test data using the `LpLoss` metric."
]
},
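{
"cell_type": "markdown",
"id": "d6192b4e",
"metadata": {},
"source": [
"One plausible reason for the high loss: a network built from linear layers acts on the last axis only, so with one input feature per grid point each predicted value depends solely on the permeability at that same point, whereas the PDE solution depends on the whole field. The sketch below illustrates this locality with a small hand-written MLP in plain PyTorch. `mlp_demo` is a hypothetical stand-in and not PINA's `FeedForward`; that `FeedForward` behaves the same way is an assumption."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e70a3c58",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"# hypothetical stand-in for a pointwise MLP (not PINA's FeedForward)\n",
"mlp_demo = torch.nn.Sequential(\n",
"    torch.nn.Linear(1, 16),\n",
"    torch.nn.Tanh(),\n",
"    torch.nn.Linear(16, 1),\n",
")\n",
"\n",
"# one synthetic permeability field on an 8x8 grid, with a feature axis\n",
"k_demo = torch.rand(1, 8, 8, 1)\n",
"k_pert = k_demo.clone()\n",
"k_pert[0, 3, 4, 0] += 1.0  # perturb a single grid point\n",
"\n",
"with torch.no_grad():\n",
"    diff = (mlp_demo(k_pert) - mlp_demo(k_demo)).abs().squeeze()\n",
"\n",
"# count the grid points whose prediction changed noticeably;\n",
"# only the perturbed point is expected to be affected\n",
"print(int((diff > 1e-6).sum()))"
]
},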
{
"cell_type": "code",
"execution_count": 9,
"id": "0e2a6aa4",
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-19T13:35:31.295336Z",
"start_time": "2024-09-19T13:35:31.256308Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Final error training 28.54%\n",
"Final error testing 28.58%\n"
]
}
],
"source": [
"from pina.loss import LpLoss\n",
"\n",
"# make the metric\n",
"metric_err = LpLoss(relative=False)\n",
"\n",
"model = solver.model\n",
"err = (\n",
" float(\n",
" metric_err(u_train.unsqueeze(-1), model(k_train.unsqueeze(-1))).mean()\n",
" )\n",
" * 100\n",
")\n",
"print(f\"Final error training {err:.2f}%\")\n",
"\n",
"err = (\n",
" float(metric_err(u_test.unsqueeze(-1), model(k_test.unsqueeze(-1))).mean())\n",
" * 100\n",
")\n",
"print(f\"Final error testing {err:.2f}%\")"
]
},
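{
"cell_type": "markdown",
"id": "f81b4d69",
"metadata": {},
"source": [
"For reference, a relative error normalizes the discrepancy by the size of the target, which is what makes a percentage reading natural. For a single sample with true field $u$ and prediction $\\hat{u}$ (notation introduced here for illustration):\n",
"\n",
"$$\n",
"e_{\\mathrm{abs}} = \\|u - \\hat{u}\\|_2, \\qquad e_{\\mathrm{rel}} = \\frac{\\|u - \\hat{u}\\|_2}{\\|u\\|_2}.\n",
"$$\n",
"\n",
"The metric above is constructed with `relative=False`, so the values printed as percentages are presumably based on the unnormalized quantity. How `LpLoss` reduces over the tensor axes is not spelled out in this tutorial, so treat this reading as an assumption."
]
},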
{
"cell_type": "markdown",
"id": "6b5e5aa6",
"metadata": {},
"source": [
"## Solving the Problem with a Fourier Neural Operator\n",
"\n",
"We will now solve the Darcy flow problem using a Fourier Neural Operator (FNO). Since we are learning a mapping between functions—i.e., an operator—this approach is more suitable and often yields better performance, as we will see."
]
},
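{
"cell_type": "markdown",
"id": "0a2c5e7b",
"metadata": {},
"source": [
"As a brief reminder of the architecture described in the referenced paper, an FNO lifts the input to a higher-dimensional channel space, applies a sequence of Fourier layers, and projects the result back to the output dimension. Each Fourier layer updates the hidden representation $v_t$ as\n",
"\n",
"$$\n",
"v_{t+1}(x) = \\sigma\\Big( W v_t(x) + \\mathcal{F}^{-1}\\big( R_\\phi \\cdot \\mathcal{F}(v_t) \\big)(x) \\Big),\n",
"$$\n",
"\n",
"where $\\mathcal{F}$ is the Fourier transform, $R_\\phi$ is a learnable weight acting on a truncated set of low-frequency modes, $W$ is a pointwise linear map, and $\\sigma$ is a nonlinearity. In the code below, `lifting_net` and `projecting_net` play the roles of the lifting and projection maps. Reading `n_modes` as the number of retained modes and `inner_size` as the hidden channel width is our interpretation of the argument names."
]
},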
{
"cell_type": "code",
"execution_count": 10,
"id": "9af523a5",
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-19T13:35:44.717807Z",
"start_time": "2024-09-19T13:35:31.306689Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: True (mps), used: False\n",
"TPU available: False, using: 0 TPU cores\n",
"HPU available: False, using: 0 HPUs\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6fbb56905e4c4799973669f533a2d73c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"`Trainer.fit` stopped: `max_epochs=10` reached.\n"
]
}
],
"source": [
"# make model\n",
"lifting_net = torch.nn.Linear(1, 24)\n",
"projecting_net = torch.nn.Linear(24, 1)\n",
"model = FNO(\n",
" lifting_net=lifting_net,\n",
" projecting_net=projecting_net,\n",
" n_modes=8,\n",
" dimensions=2,\n",
" inner_size=24,\n",
" padding=8,\n",
")\n",
"\n",
"\n",
"# make solver\n",
"solver = SupervisedSolver(problem=problem, model=model, use_lt=False)\n",
"\n",
"# make the trainer and train\n",
"trainer = Trainer(\n",
" solver=solver,\n",
" max_epochs=10,\n",
" accelerator=\"cpu\",\n",
" enable_model_summary=False,\n",
" batch_size=10,\n",
" train_size=1.0,\n",
" val_size=0.0,\n",
" test_size=0.0,\n",
")\n",
"trainer.train()"
]
},
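{
"cell_type": "markdown",
"id": "1b3d6f8c",
"metadata": {},
"source": [
"To make the idea of learning in Fourier space more concrete, the following is a minimal sketch of a 2D spectral convolution in plain PyTorch. It is an illustration under simplifying assumptions and not PINA's implementation: the class name `SpectralConv2dSketch`, the channel-first layout, the weight initialization, and the demo sizes are all choices made here."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2c4e709d",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"class SpectralConv2dSketch(torch.nn.Module):\n",
"    # illustrative spectral convolution (hypothetical, not PINA's code)\n",
"\n",
"    def __init__(self, channels, n_modes):\n",
"        super().__init__()\n",
"        self.n_modes = n_modes\n",
"        scale = 1.0 / (channels * channels)\n",
"        shape = (channels, channels, n_modes, n_modes)\n",
"        # complex weights for the two retained low-frequency blocks\n",
"        self.w_pos = torch.nn.Parameter(\n",
"            scale * torch.randn(shape, dtype=torch.cfloat)\n",
"        )\n",
"        self.w_neg = torch.nn.Parameter(\n",
"            scale * torch.randn(shape, dtype=torch.cfloat)\n",
"        )\n",
"\n",
"    def forward(self, v):\n",
"        # v: real tensor of shape [batch, channels, nx, ny]\n",
"        m = self.n_modes\n",
"        v_hat = torch.fft.rfft2(v)\n",
"        out_hat = torch.zeros_like(v_hat)\n",
"        # mix channels mode by mode; all higher modes stay zero\n",
"        out_hat[:, :, :m, :m] = torch.einsum(\n",
"            \"bixy,ioxy->boxy\", v_hat[:, :, :m, :m], self.w_pos\n",
"        )\n",
"        out_hat[:, :, -m:, :m] = torch.einsum(\n",
"            \"bixy,ioxy->boxy\", v_hat[:, :, -m:, :m], self.w_neg\n",
"        )\n",
"        # transform back to physical space on the original grid\n",
"        return torch.fft.irfft2(out_hat, s=v.shape[-2:])\n",
"\n",
"\n",
"# demo on synthetic data: 10 samples, 24 channels, 20x20 grid (assumed sizes)\n",
"layer = SpectralConv2dSketch(channels=24, n_modes=8)\n",
"v = torch.rand(10, 24, 20, 20)\n",
"print(layer(v).shape)"
]
},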
{
"cell_type": "markdown",
"id": "84964cb9",
"metadata": {},
"source": [
"We can clearly observe that the final loss is significantly lower when using the FNO. Let's now evaluate its error on the training and test data.\n",
"\n",
"Note that the FNO has considerably more trainable parameters than the `FeedForward` network. Therefore, we recommend using a GPU or TPU to accelerate training, especially when working with large datasets."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "58e2db89",
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-19T13:35:45.259819Z",
"start_time": "2024-09-19T13:35:44.729042Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Final error training 3.52%\n",
"Final error testing 3.67%\n"
]
}
],
"source": [
"model = solver.model\n",
"err = (\n",
" float(\n",
" metric_err(u_train.unsqueeze(-1), model(k_train.unsqueeze(-1))).mean()\n",
" )\n",
" * 100\n",
")\n",
"print(f\"Final error training {err:.2f}%\")\n",
"\n",
"err = (\n",
" float(metric_err(u_test.unsqueeze(-1), model(k_test.unsqueeze(-1))).mean())\n",
" * 100\n",
")\n",
"print(f\"Final error testing {err:.2f}%\")"
]
},
{
"cell_type": "markdown",
"id": "26e3a6e4",
"metadata": {},
"source": [
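"As we can see, the error is significantly lower with the Fourier Neural Operator: roughly 3.5% on the training data and 3.7% on the test data, compared with about 28.5% for the feedforward network!"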
]
},
{
"cell_type": "markdown",
"id": "ba1dfa4b",
"metadata": {},
"source": [
"## What's Next?\n",
"\n",
"Congratulations on completing the tutorial on solving the Darcy flow problem using **PINA**! There are many potential next steps you can explore:\n",
"\n",
"1. **Train the network longer or with different hyperparameters**: Experiment with different configurations of the neural network. You can try varying the number of layers, activation functions, or learning rates to improve accuracy.\n",
"\n",
"2. **Solve more complex problems**: The Darcy flow problem is just the beginning! Try solving other complex problems from the field of parametric PDEs. The original paper and **PINA** documentation offer many more examples to explore.\n",
"\n",
"3. **...and many more!**: There are countless directions to explore further. For instance, you could try adding physics-informed learning!\n",
"\n",
"For more resources and tutorials, check out the [PINA Documentation](https://mathlab.github.io/PINA/)."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "pina",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.21"
}
},
"nbformat": 4,
"nbformat_minor": 5
}