{ "cells": [ { "cell_type": "markdown", "id": "e80567a6", "metadata": {}, "source": [ "# Tutorial: Two-dimensional Darcy flow using the Fourier Neural Operator\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mathLab/PINA/blob/master/tutorials/tutorial5/tutorial.ipynb)\n" ] }, { "cell_type": "markdown", "id": "8762bbe5", "metadata": {}, "source": [ "In this tutorial we solve the two-dimensional Darcy flow problem presented in [*Fourier Neural Operator for Parametric Partial Differential Equations*](https://openreview.net/pdf?id=c8P9NQVtmnO). First, we import the modules needed for the tutorial. `scipy` is required to read the dataset, which is stored as a MATLAB `.mat` file." ] }, { "cell_type": "code", "execution_count": 1, "id": "5f2744dc", "metadata": { "ExecuteTime": { "end_time": "2024-09-19T13:35:28.837348Z", "start_time": "2024-09-19T13:35:27.611334Z" } }, "outputs": [], "source": [ "## routine needed to run the notebook on Google Colab\n", "try:\n", " import google.colab\n", " IN_COLAB = True\n", "except ImportError:\n", " IN_COLAB = False\n", "if IN_COLAB:\n", " !pip install \"pina-mathlab\"\n", " !pip install scipy\n", " # get the data\n", " !wget https://github.com/mathLab/PINA/raw/refs/heads/master/tutorials/tutorial5/Data_Darcy.mat\n", "\n", "import torch\n", "import matplotlib.pyplot as plt\n", "import warnings\n", "\n", "from scipy import io\n", "from pina.model import FNO, FeedForward  # models used in this tutorial\n", "from pina import Condition, Trainer\n", "from pina.solver import SupervisedSolver\n", "from pina.problem.zoo import SupervisedProblem\n", "\n", "warnings.filterwarnings('ignore')" ] }, { "cell_type": "markdown", "id": "4cf5b181", "metadata": {}, "source": [ "## Data Generation\n", "\n", "We focus on a specific PDE, the **Darcy flow** equation. It is a second-order elliptic PDE with the following form:\n", "\n", "$$\n", "-\\nabla\\cdot(k(x, y)\\nabla u(x, y)) = f(x, y) \\quad (x, y) \\in D.\n", "$$\n", "\n", "Here $u$ is the flow pressure, $k$ is the permeability field and $f$ is the forcing function. The Darcy equation models a variety of systems, including flow through porous media, elastic materials and heat conduction. The domain $D$ is the 2D unit square, with Dirichlet boundary conditions. The dataset is taken from the authors' original reference.\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "2ffb8a4c", "metadata": { "ExecuteTime": { "end_time": "2024-09-19T13:35:28.989631Z", "start_time": "2024-09-19T13:35:28.952744Z" } }, "outputs": [], "source": [ "# load the dataset\n", "data = io.loadmat(\"Data_Darcy.mat\")\n", "\n", "# extract the train/test permeability (k) and pressure (u) fields, plus the grid coordinates\n", "k_train = torch.tensor(data[\"k_train\"], dtype=torch.float)\n", "u_train = torch.tensor(data[\"u_train\"], dtype=torch.float)\n", "k_test = torch.tensor(data[\"k_test\"], dtype=torch.float)\n", "u_test = torch.tensor(data[\"u_test\"], dtype=torch.float)\n", "x = torch.tensor(data[\"x\"], dtype=torch.float)[0]\n", "y = torch.tensor(data[\"y\"], dtype=torch.float)[0]" ] }, { "cell_type": "markdown", "id": "9a9defd4", "metadata": {}, "source": [ "Let's visualize one training sample: the permeability field and the corresponding pressure solution." ] }, { "cell_type": "code", "execution_count": 3, "id": "c8501b6f", "metadata": { "ExecuteTime": { "end_time": "2024-09-19T13:35:29.108381Z", "start_time": "2024-09-19T13:35:29.031076Z" } }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAiwAAAEjCAYAAAARyVqhAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8ekN5oAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA2+UlEQVR4nO3dC3xU5Zk/8GcumUmAJIAEkigXQQFFLoqFBqFIoSDrolBLkbUlUKW7Luzqhw/Wxg83LzVV1kstLFi3iK4VkFZhu7psEQXKAiograyWJRRIIgkkgdwvczv/z/PynzETMpn3wZzJmcnv+/kcwsy8c/LO5Tx5znvOeV6bYRgGAQAAAFiYvaM7AAAAABANEhYAAACwPCQsAAAAYHlIWAAAAMDykLAAAACA5SFhAQAAAMtDwgIAAACWh4QFAAAALA8JCwAAAFgeEhaIW6tWrSKbzUbl5eVR2w4YMIDmz58fur179271XP4ZxI9zO4BE98knn9C4ceOoa9euajs4evRoaHu6ErrbzunTp9Xv2LhxI8UK/y7+nfy729Ptt9+uFogdJCwAEdTX16sg3jypAYh3Xq+XZs+eTRcuXKAXXniB/v3f/5369+/f0d2ypM8//1zFgPZOduDKOK/weQBx5fjx42S3t52fv/LKKxQIBMISlscff1z9H3tSkChOnjxJZ86cUd/3Bx54IHT/smXL6Kc//WmH9s2KCQvHAN7+W44g/eEPf+iwfnVWSFggKv7D3aVLF4pnbrc7apukpKSY9AWgI50/f1797N69e9j9TqdTLaDH5XJ1dBc6HRwSinPB485/+ctf6Pvf/z6lpaXRVVddRQ899BA1NjaGtX3jjTdo9OjRlJKSQj179qR7772XioqKwtrwnsRNN91Ehw8fpm9961sqUXnsscdCx57/5V/+hdauXUsDBw5Uj02dOlWtgyf9fvLJJ+maa65R67/77rvVkHNL//Vf/0UTJkxQx85TU1PpzjvvpP/93/8Na/PnP/9ZHRPn35GcnEyZmZn0ox/9iCoqKlp9D/gclmivveU5LNGOw/PrzcjIUP/nPSx+7bzw+/3qq6+q/3/66aeXrePpp58mh8NBX375ZZu/C6Aj8Hd84sSJ6v98WIi/x8HRw0jnsOjEjdZUVlaq35eenq6So9zcXHWf7mEr3u6uv/56FQN4ux4/fjzt3LkzrN0HH3wQiif8OzjufPHFF1HXH9yWW2oeJ/jcF36P2KRJk0IxIHiIuLVzWDgZvP/++6lPnz6q3yNHjqTXXnstrE3zWPqrX/2KBg0apHaovvGNb6hziyAypNMJgv9g88aWn59PBw8epJdeeokuXrxIr7/+unr8Zz/7GS1fvly142HgsrIy+uUvf6mSEv7D23xvixOD6dOnq8D0gx/8QG18Qb/5zW/I4/HQP/3TP6mE5Nlnn1Xr/Pa3v6025EcffZQKCgrUupcuXUobNmwIPZePlXPQmjZtGj3zzDNq5GbdunUqEHEfgskCB6W//vWvtGDBApWscELDGzb/5NfWMqhGe+1XgpMV7tuDDz5Is2bNou9+97vq/hEjRtC1115LixYtUu/FzTffHPY8vo+D2NVXX33FvxvALH//93+vvpucWP/zP/+z+iPZfPtuSRI3muMdGE4e9u3bR//wD/9AN9xwA73zzjtq+9fByQRvz/w7x4wZQ9XV1XTo0CE6cuQIfec731Ft3n//fRWneMeG2zc0NKi+3Xbbbard1z2Bnl8jv0ccT3injV8DC/5siX8/b/sc/xYvXqzixNatW1UCxIka70g19+abb1JNTY36TDimcSzlOMOxD6O9ERgQ11auXGnwx3jXXXeF3f+P//iP6v4//elPxunTpw2Hw2H87Gc/C2vz2WefGU6nM+z+iRMnquetX78+rO2pU6fU/RkZGUZlZWXo/ry8PHX/yJEjDa/XG7p/7ty5hsvlMhobG9Xtmpoao3v37sbChQvD1ltaWmqkp6eH3V9fX3/Z69y0aZP
6PXv37hW99qD+/fsbubm5odsffvihasM/g/hxbhdUVlam2vDvaYlfX3Z2tuH3+0P3HTlyRLV/9dVXL2sPYBXB7/7WrVvD7g9uT0GSuNFy29m2bZta17PPPhu6z+fzGRMmTNDaRjie3HnnnW22GTVqlNG7d2+joqIidB9v83a73Zg3b17oPv5d/Ds5hgVF2q5bxgl+j1rGieaxkpegF198UbV94403Qvd5PB4jJyfH6Natm1FdXR0WS6+66irjwoULobbbt29X9//+979v83V3ZjgklCB4j785HgFh7733Hr399tvqZFLeS+LDJ8GFRy94yPXDDz8Mey4PT/LoRmt4iJSHeIPGjh2rfvJITPPj33w/j8QED43wqAnvZcydOzesD3z4hNs27wMPPQfxoR1u981vflPd5j0nyWs3y7x58+js2bNh/ebRFe77PffcY9rvBYgVadxojrc9jgc8QhnE23pw24yGR254RPXEiROtPl5SUqIuxebRCz5MFcQjoDwCY+a2Hwn/Tn5vOMYF8UgJj9LU1tbSnj17wtrPmTOHevToEbrNh7YYj7BA63BIKEFwAGmOj4vyVTF8vJR/8k5FyzZBLYcfecg40gll/fr1C7sdTF769u3b6v18aIYFAw8fOmoNn38SxIea+Pj15s2bQycIBlVVVYleu1k4KGZlZakkZfLkySqwb9q0SQ2D87k5APGOt1lJ3GiOr0Li7aNbt25h9w8ZMkTrdz/xxBNqWxo8eLA6p+6OO+6gH/7whyohCa4/0vr4kM1///d/U11dnTq3JVa4T/xetbwaMXgIKdjnSLE0mLwEYyZcDglLgmp+ngf/MeXbfMIr7+W01DKoNB/haKm157d1/6XR10t9CJ7HwnshLTUfneE9uv3799MjjzxCo0aNUv3j53PQan7ZcSRXWvxKgl/v3/3d36lLQ//1X/+V/ud//keNuPBIE0AikMaN9sTnj/Dl19u3b1eXD//bv/2bqhmzfv36sEux25vf76dYiRYz4XJIWBJob4hP8griE7844PCJZ7xh8EbAj/MeS0fgUQ/Wu3dvmjJlSsR2vHexa9cuNcKyYsWK0P2RhoajvfavI1riw4eFnnvuOfr973+vgjqfqMsnFAMkAt5mrzRucCE63o75UEjzxIbrIeniQz18aJoXXg8nMXxyLScswUJ3ra2Pr5js1atXm6MrPJrR8oolPoTNh5qudOeH+8RXOHLsaT7Kwv0JPg5fD85hSRB8qXFzfLY847Po+cxzTlo4CWiZvfPtSJcLtyf+Q86HffjqBL5ksSW++qD5XkfLfr744otX9Nq/jmDtmUiXYvLwNC+89/e73/1OXVWFOhaQKL5O3Pibv/kb8vl86kq75qMXwW0zmpbr5qTnuuuuo6amJnWbDzfx6CtfMtx8+zx27JgakeHfHy0Z27t3b9h9fCViyxGWYNKjczk2/87S0lLasmVL6D5+D/g1c/+Dl5PDlUN0TRCnTp2iu+66Sx02OXDggKqdwIcsuA4Ae+qppygvL0+d1zFz5kx1ngU/hy81/PGPf6wuQTYTJyscvPg49C233KL+uPOIRGFhIb377rvqUsQ1a9aodrwnxZf4cWLD59NwAOK+Xulrv1J8aOzGG29UAYj3MHmPj4+n89J8lCX43uFwECQS/qN+pXFjxowZapvmyrn8XN6O+CTe1s5Baw2350uEuf4Lb3d8SfNvf/tbdblw0OrVq9VOSU5Ojqp9Erysmc+fa63GSnM8SsOXW/MJ8nw+2p/+9Cd13guPzDTHSREnbVyGgfvOFyTweXg8UtwSvx8vv/yyOhGY61jxCC/3mQ8X8w4Xzm1rBx19mRJ8PcFLET///HPje9/7npGammr06NHDWLx4sdHQ0BDW9ne/+50xfvx4o2vXrmoZOnSosWjRIuP48eOhNnyZ3rBhwy77PcFL8VavXq11iWTwUsJPPvnksvbTpk1TlzInJycbgwYNMubPn28cOnQo1Ka4uNiYNWuWugy
a282ePds4e/bsZZciSl77lVzWzPbv32+MHj1aXaLd2qWQJSUl6tLPwYMHX/aeAcTzZc2SuNHatsOXG//whz800tLS1HbM///000+1Lmt+6qmnjDFjxqgYkJKSon4nX0bNlwk39/777xu33XabasO/Z8aMGSoeNNfaZc1cjuDRRx81evXqZXTp0kXFpIKCgsviBHvllVeMgQMHqu28ecxoeVkzO3funLFgwQK1Xo4Zw4cPv+y1RoqlbV1uDZfY+J/2SHygY/CeBA/Z8iGVlnsHYD6+zJOHp/l8Gy6wBQAA5sA5LABfA5fv5uPefKgLAADMg3NYAK4Az2HCM7ly6XI+tv91r0gCAIC2IWEBuAJc2IprxfCJhbpXPgAAwJXDOSwAAABgeTiHBQAAACwPCQsAAABYXkKcw8KlkHkeFy7ME4t5ZADgcnx0uaamhrKzsy+bAM6qEDsA4iduJETCwgGn5WzBANAxioqK6JprrqF4gNgBED9xw7SEhed34dLJPLcCl0jnKynGjBkTsf3WrVtV4S0u48xTdHMp5GjzQQQFSx5fs3IZ2ZOTo7a3RZ/wN0zXQsHeomAnrfqGy+fUicRdEnkq96+ry3n9866TamXnaKeU6b9Gsuu/eQGHbG+4LlP/q558Uf8L4mzQb2sTnt/uqrw0b4qOQFLrM7+2xlkQPsFbNP7ycq12PvLSPnrva5Ugj2XcYMG+njkygNK6Rd/Oy/x1JFHu148dFwKRZ0lv6aI/8sR+LVUJ1qvW7dOfhbnCq9+Pco9sdufyJv325Q2X5v3ScaFGvy3zVEX/mxLkrNSPM+4K/RiWXCGMu+X6s06nnKvXbusouSDqh6/0XLvGDVMSFp57ZcmSJWoq8LFjx6p5FHjyO55Zs7U5GPjy0Llz51J+fj797d/+Lb355puqtsWRI0fC5m2JJDiUy8mKGQmLw21OwmJP0f8j43Cbl7A4VNX59m/LnE6HOQmLU5awOFz6X3VnkiBh8QkSloDwvRMkZQHB++y0u0T9sNk0v3v//+Vd6aGVWMeN5n3lZCUtNfp23ihIQKTtmwKCtn79z9vjl4X5Bp/gj65XPy65PLLvXZJTv73T7tZu6/DrJyDM7tFvb2/Uf+8cbpt5cTdJP2FxOvTbOoSxg3RihyBumHKg+fnnn6eFCxeqacF5EisOQDzz7YYNG1pt/4tf/EJNXPfII4/QDTfcQE8++aSaII8nwwOAzgFxAwBimrB4PB41U+WUKVO++iV2u7rNM+m2hu9v3p7xnlWk9jzFeHV1ddgCAPErFnGDIXYAxC+7GZPB8dwqffr0Cbufb/Nx6dbw/ZL2PATMU4gHF5w0BxDfYhE3GGIHQPyKj2sPW8jLy6OqqqrQwmcXAwBEg9gBEL/a/aTbXr16kcPhoHPnws8O5tuZmZmtPofvl7R3u91qAYDEEIu4wRA7AOJXu4+wuFwuGj16NO3atSusOBPfzsnJafU5fH/z9mznzp0R2wNAYkHcAIAOuayZL03Mzc2lW2+9VdVQ4MsT6+rq1Nn/bN68eXT11Ver48nsoYceookTJ9Jzzz1Hd955J23evJkOHTpEv/rVr8zoHgBYEOIGAMQ8YZkzZw6VlZXRihUr1Alwo0aNoh07doROkCssLAwrwTtu3DhVQ2HZsmX02GOPqQJQ27Zt066lEORssJHdiH4tt7NWVieiZpB+rY3kvjX6K67QL2DUmO2TFYMrFNQd0a8bRIEk4XvXT/+6fY2P7qu2gvIu0to7Vdfqr9wm6Ei3Yv16B6ypu35RrtST+t874+oMUT9smlfS2Aw7kX6tO8vEDYA2mTVjg1VmgrBJiod17GmvNoML+cc5vjSRz/gfuOJnWoXjpAlLYx9zEpZ6QcJCAZtpCUu3IsO0RIH/hsVbwuJJ0++ITbD1SBMWSSCRJCw2v2yTN744qdXOZ3jpw6a31MmsaWlpFE+x4+L/DdQqHHdeWOm2TFA4rkJQkfaCX78KbKVfVtn1gqDSbblXv22
ZR1YBuaxRsO4G/eS+olq/LWuSVLq9IKioLal0WybbZruUCSrdlppY6fbLs9HbGF7aTdu14kZcXiUEAAAAnQsSFgAAALA8JCwAAABgeUhYAAAAwPKQsAAAAIDlIWEBAAAAy0PCAgAAAJaHhAUAAAAsDwkLAAAAWB4SFgAAAOiccwl1FE+Gj+wp0efc8fSRlTm+YVD08sJBxVXp2m3Tetdqt236c3cyqyR+bV/9EtF22ZRGIgFBuf2A/hRFij9ZMP2AoN6+ZIqAgFM2n0CXc/r9aOyjX369y8mLon7Yuut9p+0BD9F50aohBvySYMDtBZPc+AX7vAHJ3BsWmTrn0hMEfy9sJk1bIh1asEnaWuNz0YERFgAAALA8JCwAAABgeUhYAAAAwPKQsAAAAIDlIWEBAAAAy0PCAgAAAJaHhAUAAAA6X8KSn59P3/jGNyg1NZV69+5NM2fOpOPHj7f5nI0bN5LNZgtbkpOT27trAGBRiBsAEPOEZc+ePbRo0SI6ePAg7dy5k7xeL02dOpXq6urafF5aWhqVlJSEljNnzrR31wDAohA3ACDmlW537Nhx2V4Q7zEdPnyYvvWtb0V8Hu8dZWZmtnd3ACAOIG4AQIeX5q+qqlI/e/bs2Wa72tpa6t+/PwUCAbrlllvo6aefpmHDhrXatqmpSS1B1dXV6qfN5VdLNH16X+qTria//ts0oId+2fNjRwdot7V1k00nQKTf3u7RL81sM7PcfkpAu62REv1zbs6hMWVDkNOpv25Pvf4cAXXJss3N7tF/85IvCsqpp6eI+uGo0PxOB2SfSazjRluxI5FJyuezgGGNcvsBcbTRY5OU2ldP0G9q2AXTethtph0LMQRx14ijMv6mnnTLQeThhx+m2267jW666aaI7YYMGUIbNmyg7du30xtvvKGeN27cOCouLo54vDs9PT209O3b18RXAQCxZFbcYIgdAPHL1ISFj0kfO3aMNm/e3Ga7nJwcmjdvHo0aNYomTpxIb7/9NmVkZNDLL7/cavu8vDy1BxZcioqKTHoFABBrZsUNhtgBEL9MOyS0ePFi+s///E/au3cvXXPNNaLnJiUl0c0330wFBQWtPu52u9UCAInFzLjBEDsA4le7j7AYhqGCzjvvvEMffPABXXvtteJ1+P1++uyzzygrK6u9uwcAFoS4AQAxH2Hh4dw333xTHVfmmgqlpaXqfj5enJJy6WQ/Hsa9+uqr1fFk9sQTT9A3v/lNuu6666iyspJWr16tLk984IEH2rt7AGBBiBsAEPOEZd26dern7bffHnb/q6++SvPnz1f/LywsJLv9q8Gdixcv0sKFC1WQ6tGjB40ePZr2799PN954Y3t3DwAsCHEDAGKesPDQbjS7d+8Ou/3CCy+oBQA6J8QNAIgGcwkBAACA5SFhAQAAAMtDwgIAAACWZ3pp/lhypXjJ0SV6TeK+qZWi9doFpZw/+uw6/RW7BaXoHcJy0j79EsoBSaV2p6wfdkFJ/G7dGrXbZnRre1K8y9qn1JIZSuvStNueKeolWrenu3597UCSYHoFr/73DmJHUhLfI6i9Li2f7zepNL9ZpfaZrLq8LIbZBLFXVBJfMm2J8C91wKH/jhgOwbhFIpfmBwAAAGgPSFgAAADA8pCwAAAAgOUhYQEAAADLQ8ICAAAAloeEBQAAACwPCQsAAABYHhIWAAAAsDwkLAAAAGB5SFgAAADA8hKqNL8RsFEgEL108PHy3qL11hSnmVO6PklQIl3jdYVx6a/b7vJrt03u4hF1IyNVv4T+oLRy7bZDu5WI+pGdpD8dQ13Ard32I8dA7balqakkYQu4tNv6BaX57XWNsn449cKELYCS/y35BUXjJW0lZfy9hizMewU14/2CfV6/cIoAs4iry9sFpfkF8V9Sbj/glHXaEJXmF6zbjtL8AAAAAG1CwgIAAACdL2FZtWoV2Wy2sGXo0KFtPmfr1q2qTXJyMg0fPpzee++99u4WAFgY4gYAdMgIy7Bhw6i
kpCS07Nu3L2Lb/fv309y5c+n++++nTz/9lGbOnKmWY8eOmdE1ALAoxA0AiHnC4nQ6KTMzM7T06tUrYttf/OIXdMcdd9AjjzxCN9xwAz355JN0yy230Jo1a8zoGgBYFOIGAMQ8YTlx4gRlZ2fTwIED6b777qPCwsKIbQ8cOEBTpkwJu2/atGnq/kiampqouro6bAGA+GZ23GCIHQDxq90TlrFjx9LGjRtpx44dtG7dOjp16hRNmDCBampqWm1fWlpKffr0CbuPb/P9keTn51N6enpo6du3b3u/DACIoVjEDYbYARC/2j1hmT59Os2ePZtGjBih9nj4RLjKykp666232u135OXlUVVVVWgpKipqt3UDQOzFIm4wxA6A+GV64bju3bvT4MGDqaCgoNXH+Vj1uXPnwu7j23x/JG63Wy0AkJjMiBsMsQMgfpleh6W2tpZOnjxJWVlZrT6ek5NDu3btCrtv586d6n4A6JwQNwDA9IRl6dKltGfPHjp9+rS69HDWrFnkcDjUJYhs3rx5alg26KGHHlLHrZ977jn6y1/+ouoxHDp0iBYvXtzeXQMAi0LcAICYHxIqLi5WQaaiooIyMjJo/PjxdPDgQfV/xmf+2+1f5Unjxo2jN998k5YtW0aPPfYYXX/99bRt2za66aabxL/bMC4t0dScSRet1+7VbxtIFqzY5zBljgrFod/eKZhLqEfXBlE3hnQPH7ZvS07aSe22tyRHvoKkNRl2n3bbIr/+IYPCpqu029oFc5IwydQrDo/w+yEQqKzSa2fI5pmyStwwU0DwIUrm/PGI5vuRzf8SELT3BRymzH90qb0589bY7bI5r2yC7TYgiLuSmG4I5xIKCP6yG07B59JsG0yIhGXz5s1tPr579+7L7uOT7XgBgM4JcQMAosFcQgAAAGB5SFgAAADA8pCwAAAAgOUhYQEAAADLQ8ICAAAAloeEBQAAACwPCQsAAABYHhIWAAAAsDwkLAAAAGB5ps/WHEuuT7uRwx29Nn5Skmy9AcHkrka9oIyzoB8Bl6z0ut9mTqn2rkmy8uuZ7mrttgNd5/XbCr+53ezdtNvWG7XabR02/TLffr9s/8ApmBLCFhCU+U6WbQABj15HAob+9Afxyq8z90czHsE+oaTcvqSMv1ewXtVeUm5fUMZfWmrfrNL8UpLS/JIhAMHbLCq1zwKCUv6GQ/A+2zr2M8EICwAAAFgeEhYAAACwPCQsAAAAYHlIWAAAAMDykLAAAACA5SFhAQAAAMtDwgIAAACdL2EZMGAA2Wy2y5ZFixa12n7jxo2XtU1Ojl5LBQASC2IHAMS0cNwnn3xCfr8/dPvYsWP0ne98h2bPnh3xOWlpaXT8+PHQbQ48ANC5IHYAQEwTloyMjLDbP//5z2nQoEE0ceLEiM/hIJOZmdneXQGAOILYAQAdVprf4/HQG2+8QUuWLGlzz6e2tpb69+9PgUCAbrnlFnr66adp2LBhEds3NTWpJai6+lL5d65ArVOF2l0pK69t2G2mlESWlPz3pcj2HCVVrb3J+l+DWq9L1I96v377esEbUm/UifpBgUbtpmX+FO22FR79kv8+j6xEuktQ6d7u1f9O2xpk0yvYXXql/O1ctl7/bbZU7ND11fiPnoChf9Q9IDhC7xe09QrruktK+fsFgUZSxt9MNjPbC8r4G6aW5ieTyvh37Gmvpv72bdu2UWVlJc2fPz9imyFDhtCGDRto+/btKkBx4Bk3bhwVFxdHfE5+fj6lp6eHlr59+5r0CgCgIyB2AEBME5Zf//rXNH36dMrOzo7YJicnh+bNm0ejRo1SQ79vv/22Ghp++eWXIz4nLy+PqqqqQktRUZFJrwAAOgJiBwDE7JDQmTNn6P3331dBRCIpKYluvvlmKigoiNjG7XarBQASD2IHAMR0hOXVV1+l3r1705133il6Hl8l8Nlnn1FWVpZZXQMAC0PsAICYJSx8LJmDTm5uLjmd4YM4PITLw7JBTzzxBP3hD3+gv/7
1r3TkyBH6wQ9+oPawHnjgATO6BgAWhtgBADE9JMTDuYWFhfSjH/3ossf4frv9qzzp4sWLtHDhQiotLaUePXrQ6NGjaf/+/XTjjTea0TUAsDDEDgCIacIydepUMvgSx1bs3r077PYLL7ygFgAAxA4AiARzCQEAAIDlIWEBAAAAy0PCAgAAAJaHhAUAAAA691xCsZZ22k/OJOlsH+08l5DetCuKz62/Xns32QwYtoB+Ltro0J/vp8ytP3cO+yJZf2K6dGeDdluPZCIOIkq2ebXbft50tXbbgppe2m0DdYIvB3/mkil/JNNjNbvSRoctWa/Qmo3nlWmnuYSsSjBlk+IXzEQj+U5L5vuRzuETEMwP5AsI+iGZ4EzYXvixiNhsgrVLXqJovcK55Ozm/H2T9qO9YYQFAAAALA8JCwAAAFgeEhYAAACwPCQsAAAAYHlIWAAAAMDykLAAAACA5SFhAQAAAMtDwgIAAACWh4QFAAAALA8JCwAAAFheQpXmT77oI6fTF7VdwCkscy+ooCxZtz9ZP190Nsn6bPcKclFBGf8mo4uoH38J9NFuW92UrN32825Zon6kOPRL859rTNVue7q8p3ZbZ7VsOgGHxzDlO0qSUtwsSXPqhoBstZ2BX7BPGDDMaeuX1GmX9llQi15amt8Qttder7S9Sf2QlPGXdsGQfOSi6QRQmh8AAACgfROWvXv30owZMyg7O5tsNhtt27Yt7HHDMGjFihWUlZVFKSkpNGXKFDpx4kTU9a5du5YGDBhAycnJNHbsWPr444+lXQMAi0LcAICYJyx1dXU0cuRIFSha8+yzz9JLL71E69evp48++oi6du1K06ZNo8bGyNO4btmyhZYsWUIrV66kI0eOqPXzc86fPy/tHgBYEOIGAMQ8YZk+fTo99dRTNGvWrMse472kF198kZYtW0Z33303jRgxgl5//XU6e/bsZXtUzT3//PO0cOFCWrBgAd14440qaHXp0oU2bNggf0UAYDmIGwBgqXNYTp06RaWlpWo4Nyg9PV0N1R44cKDV53g8Hjp8+HDYc+x2u7od6TlNTU1UXV0dtgBAfIpV3GCIHQDxq10TFg46rE+f8CtD+HbwsZbKy8vJ7/eLnpOfn68CWnDp27dvu70GAIitWMUNhtgBEL/i8iqhvLw8qqqqCi1FRUUd3SUAiAOIHQDxq10TlszMTPXz3LlzYffz7eBjLfXq1YscDofoOW63m9LS0sIWAIhPsYobDLEDIH61a8Jy7bXXqmCxa9eu0H18jJjP+s/JyWn1OS6Xi0aPHh32nEAgoG5Heg4AJA7EDQAwpdJtbW0tFRQUhJ0wd/ToUerZsyf169ePHn74YXU1wPXXX68C0fLly1XthZkzZ4aeM3nyZHW1wOLFi9VtvjQxNzeXbr31VhozZoy6YoAvg+Sz/wEg/iFuAEDME5ZDhw7RpEmTQrc5aDAOHBs3bqSf/OQnKmj8+Mc/psrKSho/fjzt2LFDFXYKOnnypDppLmjOnDlUVlamCkfxCXOjRo1Sz2l5Ql00SVVN5NSofh5IThKt1+bXrzluOARlrRv02zoapWXdJW31yy07PLJBuab6FO22RZWaJeCJqLhrD1E/7En6n2HAI3ivq/U3oeRq4fQKgtL8ktLdRpLsu2Tvpjcdgy3gIPpqs46buCEhnX1AVEJfMODtl5TEF9Vel5fQN2u9ZpX9l5baN8xqLJ0jQEA0VUccsRlcBCHO8fAxn/E/6eafktPhjp+ExaXf1pci+yPj7abfvilNfwNu6iHb2Jt66H+9vN392m1tXaPPGWW5hOW87DPsUqL/3nU7q/9+pHxZI+qHvaZBq50v0ETvn16jTmaNl3NDgrHj4v8NpLTU6NvjSW+taP1FPv334UuffhJe5tOf7+qir6t2W9Xeqz9fWI1Pf/6vKo9+W7Vur377Wo/+zk5tY/S/Ec01NOiv21+r/7fFIZhbzHVRtqPY5bwgdnypHzu6nLwo6of/+FejqpH4DC/
tpu1acSMurxICAACAzgUJCwAAAFgeEhYAAACwPCQsAAAAYHlIWAAAAMDykLAAAACA5SFhAQAAAMtDwgIAAACWh4QFAAAALA8JCwAAACTeXEKJwFHVKHyCOXNr2AVzutibhHMJNel/tM4G/XUn1clyXHel/nvn7abfZ1+K7KsbcJkzx4ddMGdTkqyqu2g+KMncIf5ustLkZNf7zAP+ThlOLM8vmM9IPP+RYF4eyXrNnB9IOpeQaKKugH5bm6itfhesNKdRe8MICwAAAFgeEhYAAACwPCQsAAAAYHlIWAAAAMDykLAAAACA5SFhAQAAgMRLWPbu3UszZsyg7OxsstlstG3bttBjXq+XHn30URo+fDh17dpVtZk3bx6dPXu2zXWuWrVKrav5MnTo0Ct7RQBgOYgbABDzhKWuro5GjhxJa9euveyx+vp6OnLkCC1fvlz9fPvtt+n48eN01113RV3vsGHDqKSkJLTs27dP2jUAsCjEDQD4usSVnqZPn66W1qSnp9POnTvD7luzZg2NGTOGCgsLqV+/fpE74nRSZmamtDsAEAcQNwDA8uewVFVVqaHa7t27t9nuxIkTaih44MCBdN9996lAFUlTUxNVV1eHLQCQOMyIGwyxAyB+mVpLu7GxUR2bnjt3LqWlpUVsN3bsWNq4cSMNGTJEDes+/vjjNGHCBDp27BilpqZe1j4/P1+1acle7yG7Thl9m7A0s7Qssiabx6fftlFWmt/eoL9uZ52gNH+N7CvjT9Zfty9FUBLcLfsM/S5BCXHZW63Nrv+RKEn1+l88b1f99y7glJXmd9n13jufL2DpuNFW7IArIym3HyCbaaX5/QFBPwQl8VV7v6Q0P5lSbt/u12+r1u03Z1oP0owFcTfCwifSff/73yfDMGjdunVttuWh4tmzZ9OIESNo2rRp9N5771FlZSW99dZbrbbPy8tTe2DBpaioyKRXAQCxZGbcYIgdAPHLaWbQOXPmDH3wwQdt7iW1hoeBBw8eTAUFBa0+7na71QIAicPsuMEQOwDil92soMPHlt9//3266qqrxOuora2lkydPUlZWVnt3DwAsCHEDANo9YeGgcPToUbWwU6dOqf/zyW4cdL73ve/RoUOH6De/+Q35/X4qLS1Vi8fjCa1j8uTJ6iqAoKVLl9KePXvo9OnTtH//fpo1axY5HA51DBsA4h/iBgDE/JAQB5VJkyaFbi9ZskT9zM3NVYWc/uM//kPdHjVqVNjzPvzwQ7r99tvV/3kvqLy8PPRYcXGxCjIVFRWUkZFB48ePp4MHD6r/A0D8Q9wAgJgnLBw8+IS4SNp6LIj3iJrbvHmztBsAEEcQNwDg68JcQgAAAGB5SFgAAADA8pCwAAAAgOUhYQEAAIDOXZo/1mweL9ns0XMwI0VYOMrjFXRCULpY40TDIHuToA+8akkJ5Tr9vNXhlNWtN5IcprQNuGVf3YBLUPZf0NZwmleq2ttV//3wJ+mv19koLKGv+52WTnkBnYqk1D4zRKX5BW39sv10Q1Ca3+YTtJX8WfGZNw2ILSCoze83aZ4aTRhhAQAAAMtDwgIAAACWh4QFAAAALA8JCwAAAFgeEhYAAACwPCQsAAAAYHlIWAAAAMDykLAAAACA5SFhAQAAAMtDwgIAAACWl1Cl+Rv79SSnMzlqO/eXVbIVu/TrntsamsgUfr+oua1BVsrfLDpTJYQIyv7bk5ymfYaGS3/dAZd+nz09hFNCCKZucFcJymsL6ZbuFpX4hk5Xbl9amt8XsJvSNiAszU9e/fZ2r82kttpNL7X362+LNp9guw2gND8AAABAm5CwAAAAQOIlLHv37qUZM2ZQdnY22Ww22rZtW9jj8+fPV/c3X+64446o6127di0NGDCAkpOTaezYsfTxxx9LuwYAFoW4AQAxT1jq6upo5MiRKlBEwoGmpKQktGzatKnNdW7ZsoWWLFlCK1eupCNHjqj1T5s2jc6fPy/tHgBYEOIGAMT8pNvp06erpS1ut5s
yMzO11/n888/TwoULacGCBer2+vXr6d1336UNGzbQT3/6U2kXAcBiEDcAwJLnsOzevZt69+5NQ4YMoQcffJAqKioitvV4PHT48GGaMmXKV52y29XtAwcOtPqcpqYmqq6uDlsAIL6ZHTcYYgdA/Gr3hIWHdV9//XXatWsXPfPMM7Rnzx61Z+WPcFlueXm5eqxPnz5h9/Pt0tLSVp+Tn59P6enpoaVv377t/TIAIIZiETcYYgdA/Gr3Oiz33ntv6P/Dhw+nESNG0KBBg9Te0+TJk9vld+Tl5alj10G8l4TAAxC/YhE3GGIHQPwy/bLmgQMHUq9evaigoKDVx/kxh8NB586dC7ufb0c6ns3HutPS0sIWAEgcZsQNhtgBEL9MT1iKi4vVseisrKxWH3e5XDR69Gg1FBwUCATU7ZycHLO7BwAWhLgBAF87YamtraWjR4+qhZ06dUr9v7CwUD32yCOP0MGDB+n06dMqeNx999103XXXqcsNg3iId82aNaHbPET7yiuv0GuvvUZffPGFOuGOL4MMnv0PAPENcQMAYn4Oy6FDh2jSpEmh28Hjwbm5ubRu3Tr685//rAJIZWWlKhI1depUevLJJ9VQbNDJkyfVSXNBc+bMobKyMlqxYoU6YW7UqFG0Y8eOy06oi8ZVUU9OR/Q5d/w9uojW67hYr93WSHZpt7XVNeh3wi+bw8HwCeYe8nr025o5X4xdMNeIQ38OH2ZzCT4XwbxD1L2bdlOHR7BeIkopE3zmgrfOlyx772r7Juut1xufccPMPTy7Tf8zdJCkrf526BD0wUzSuYT8knmKRHMJyfpBgvY2k+YHskvm+yFuL2jr1f9+2CR/V0xgMwzBDGsWxSfO8Rn/377pEXI6ok8wF+ii/8dLmrCQw26NhKVJkIQgYQknSFgCgoTF2zOFRCRvtYkJiyddr73P20iHf7uMqqqq4ubckGDsuPh/AyktNfq2e8pbK1p/kV//+3HW20O7bZlP//0t9+n3gVV49Ntf8Ojv/F1oku0oVjXpJcqspkG/bUOdLP4HavXjgbNKf9tyVelvtO5KWdxNqdBv36VEf8JeV1HkUgOt8Z0ujN7G8NJu2q4VNzCXEAAAAFgeEhYAAACwPCQsAAAAYHlIWAAAAMDykLAAAACA5SFhAQAAAMtDwgIAAACWh4QFAAAALA8JCwAAACReaX4rM5IcZGhUQLX5ZFVjJaX8HRfqtNsakhLw1bIKmxTQL6FsePXrOBs+n7Af5lTGtQkqCjOjUb+aoz0tVb8fHv33w3Ve/7vBAi79zdNzlX6lTylvil5FTr9DWPI8Dpn5EiVl/K0iICixbEhL8wvK7fv9gtL8PlnssPkE5fYFletFbYVh1y4o5S/6eyisuN7eMMICAAAAloeEBQAAACwPCQsAAABYHhIWAAAAsDwkLAAAAGB5SFgAAADA8pCwAAAAQOIlLHv37qUZM2ZQdnY22Ww22rZtW9jjfF9ry+rVqyOuc9WqVZe1Hzp06JW9IgCwHMQNAIh5wlJXV0cjR46ktWvXtvp4SUlJ2LJhwwYVSO6555421zts2LCw5+3bt0/aNQCwKMQNAIh5pdvp06erJZLMzMyw29u3b6dJkybRwIED2+6I03nZcwEgMSBuAIClS/OfO3eO3n33XXrttdeitj1x4oQaLk5OTqacnBzKz8+nfv36tdq2qalJLUHV1dXqZ2NGCjmTopcod1/QL9PO7PUe/cbO6FMDXAmbyyVqbzQ06q9bYzqD0Hqbve9a7SWl+QXTCRDJ3g+74P0zvF79FZdd0G5qc7tJJPsqU0pxV/eXbfZJDXrrtnkNS8eNtmKHWRwUiLsy/pJy+wFBuX3Jepk/oN/e7xesW9JWWJrfJghhZrVldsG2aJeU5jfMmWrFEifdcsBJTU2l7373u222Gzt2LG3cuJF27NhB69ato1OnTtGECROopqam1fYclNLT00NL3759TXoFABBrZsUNhtgBEL9MTVj4OPR9992n9n7
awkPFs2fPphEjRtC0adPovffeo8rKSnrrrbdabZ+Xl0dVVVWhpaioyKRXAACxZlbcYIgdAPHLtENCf/zjH+n48eO0ZcsW8XO7d+9OgwcPpoKCglYfd7vdagGAxGJm3GCIHQDxy7QRll//+tc0evRodWWAVG1tLZ08eZKysrJM6RsAWBPiBgC0W8LCQeHo0aNqYXzcmP9fWFgYdiLb1q1b6YEHHmh1HZMnT6Y1a9aEbi9dupT27NlDp0+fpv3799OsWbPI4XDQ3Llzpd0DAAtC3ACAmB8SOnTokLrcMGjJkiXqZ25urjoBjm3evJkMw4gYOHgvqLy8PHS7uLhYta2oqKCMjAwaP348HTx4UP0fAOIf4gYAfF02gyNEnOM9Mz7jP2fq4x1+WbNNcomYgK2mXtTeqDLncs1Afb0lLmu2JQkva06J/r0IcSWRGaSXNfsFlzX7uur3uXKQ25TLmv3eRjr81jJ1MmtaWhrFU+y4+H8DKS01+oBzoa9WtP4iXxfttl/6emi3LfPpv7/nvbLPosyTqt22vKmrdtuLTfrvBbtQn6LdtqZOf/v21gpjR43+fn1Sjf4l0K4q/bbui7I/013K9GNpSmmDdltHiX4ZB+Yr/jJ6G8NLu2m7VtzAXEIAAABgeUhYAAAAwPKQsAAAAIDlIWEBAACAzj2XUKzZvQGyG9FPevV2k51UaU/Rf5tcpZHLgrdkaxTMUSSZ34bX3SNdu23gXJl2W8Mvm9TCLjjR1PDrz2lkc8hybX9tHZnB0U3/pEOjZ3fRum1e/ffaL3if7T5RN4h0z/eL+9P3O5bDpDfQLlyvtL0Z8w4xQ9LerLaqPZnUD/2m4qmjDEpIGGEBAAAAy0PCAgAAAJaHhAUAAAAsDwkLAAAAWB4SFgAAALA8JCwAAABgeUhYAAAAwPKQsAAAAIDlIWEBAAAAy0uISreGcamsn8/XZMr67X79soF2v34fbAFBpVtJW9Vev88BQ3/dAUNWcdcuqPwY/Bx12ITVKgOGtLyrHkPw3pHgu6HWLSgq7PPp73v4PbJqxeTV+1z83kbx59jRgn2trtUrJVrjk5UcrRO0r/fpfy4Nfv3vc5OwSrbHo9/e69H//vsaZX9u/A363+lAvf56Aw3CsrGN+hW4/Y12U8KB3yPbpnyCKtk+/6XtVocRkMUwn8bfCx95teNGQiQsNTWXyuF/9Mefd3RXoCX9bUFGFoPNU21S2zjfHtPT9aeGsELs6H/L6Y7uCkCnVqMRN2xGPO0ORRAIBOjs2bOUmppKNttXe97V1dXUt29fKioqorS0NEo0if76OsNrTKTXx6GEg052djbZ7fFxtBmxA68vXlUnyGuUxI2EGGHhF3nNNddEfJw/zHj+QKNJ9NfXGV5jory+eBlZCULswOuLd2kJ8Bp140Z87AYBAABAp4aEBQAAACwvoRMWt9tNK1euVD8TUaK/vs7wGhP99cWrRP9c8Prin7sTvMaEPOkWAAAAEltCj7AAAABAYkDCAgAAAJaHhAUAAAAsDwkLAAAAWB4SFgAAALC8hE5Y1q5dSwMGDKDk5GQaO3Ysffzxx5QIVq1apcqIN1+GDh1K8Wzv3r00Y8YMVZ6ZX8+2bdvCHueL2VasWEFZWVmUkpJCU6ZMoRMnTlCivL758+df9pnecccdHdbfzixR40Yixg7EjfmdKm4kbMKyZcsWWrJkibpO/ciRIzRy5EiaNm0anT9/nhLBsGHDqKSkJLTs27eP4lldXZ36jPiPRWueffZZeumll2j9+vX00UcfUdeuXdXn2dho1uyKsX19jANN889006ZNMe0jJH7cSLTYgbhBnStuGAlqzJgxxqJFi0K3/X6/kZ2dbeTn5xvxbuXKlcbIkSONRMVfy3feeSd0OxAIGJmZmcbq1atD91VWVhput9vYtGmTEe+vj+Xm5hp33313h/UJEj9uJHrsQNxIfAk5wuLxeOjw4cNq+K/5JGd8+8CBA5QIeFiThwkHDhxI9913HxU
WFlKiOnXqFJWWloZ9njxZFg/XJ8rnyXbv3k29e/emIUOG0IMPPkgVFRUd3aVOpTPEjc4UOxA3Ek9CJizl5eXk9/upT58+Yffzbf4Cxzve4DZu3Eg7duygdevWqQ1zwoQJaoruRBT8zBL18wwO677++uu0a9cueuaZZ2jPnj00ffp09T2G2Ej0uNHZYgfiRuJxdnQHQI6/kEEjRoxQQah///701ltv0f3339+hfYMrc++994b+P3z4cPW5Dho0SO09TZ48uUP7BokDsSOx3NvJ4kZCjrD06tWLHA4HnTt3Lux+vp2ZmUmJpnv37jR48GAqKCigRBT8zDrL58l4uJ6/x4n6mVpRZ4sbiR47EDcST0ImLC6Xi0aPHq2GyYICgYC6nZOTQ4mmtraWTp48qS7dS0TXXnutCjDNP8/q6mp11n8ifp6suLhYHYtO1M/Uijpb3Ej02IG4kXgS9pAQX5qYm5tLt956K40ZM4ZefPFFdYnYggULKN4tXbpUXZvPQ7lnz55Vl2DynuHcuXMpngNn870CPrZ+9OhR6tmzJ/Xr148efvhheuqpp+j6669XgWj58uXqxMGZM2dSvL8+Xh5//HG65557VIDlPyA/+clP6LrrrlOXYELsJHLcSMTYgbjxeOeKG0YC++Uvf2n069fPcLlc6nLFgwcPGolgzpw5RlZWlnpdV199tbpdUFBgxLMPP/xQXbbXcuHL9oKXKC5fvtzo06ePuixx8uTJxvHjx41EeH319fXG1KlTjYyMDCMpKcno37+/sXDhQqO0tLSju90pJWrcSMTYgbgxtVPFDRv/09FJEwAAAECnO4cFAAAAEgsSFgAAALA8JCwAAABgeUhYAAAAwPKQsAAAAIDlIWEBAAAAy0PCAgAAAJaHhAUAAAAsDwkLAAAAWB4SFgAAALA8JCwAAABAVvf/AFzg6Qh9JoIaAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.subplot(1, 2, 1)\n", "plt.title('permeability')\n", "plt.imshow(k_train[0])\n", "plt.subplot(1, 2, 2)\n", "plt.title('field solution')\n", "plt.imshow(u_train[0])\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "89a77ff1", "metadata": {}, "source": [ "We now create the problem class. Learning a neural operator from paired input and output fields is a supervised learning task, so we use `SupervisedProblem`, with the permeability fields as inputs and the pressure fields as outputs." ] }, { "cell_type": "code", "execution_count": 4, "id": "8b27d283", "metadata": { "ExecuteTime": { "end_time": "2024-09-19T13:35:29.136572Z", "start_time": "2024-09-19T13:35:29.134124Z" } }, "outputs": [], "source": [ "# make problem: unsqueeze(-1) adds a trailing channel dimension, (N, H, W) -> (N, H, W, 1)\n", "problem = SupervisedProblem(\n", " input_=k_train.unsqueeze(-1), output_=u_train.unsqueeze(-1)\n", ")" ] }, { "cell_type": "markdown", "id": "1096cc20", "metadata": {}, "source": [ "## Solving the problem with a FeedForward Neural Network\n", "\n", "We first solve the problem with a feedforward neural network. Since we train on pairs of input and output fields, we use the `SupervisedSolver`."
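, "\n", "\n", "Note that with `input_dimensions=1` the network is applied along the last (channel) dimension of the `(N, H, W, 1)` input tensor, so each grid point is mapped to an output value independently of its neighbours. The next cell is a minimal sketch of this behaviour. It assumes that `FeedForward` acts on the last dimension like the plain `torch.nn.Sequential` used here as a stand-in; `pointwise_net` is a hypothetical name, not part of PINA." ] },
{ "cell_type": "code", "execution_count": null, "id": "c0a1e5d2", "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "# Stand-in for a pointwise feedforward network (assumption: like PINA's\n", "# FeedForward, it only acts on the last, channel dimension of its input)\n", "pointwise_net = torch.nn.Sequential(\n", "    torch.nn.Linear(1, 20),\n", "    torch.nn.Tanh(),\n", "    torch.nn.Linear(20, 1),\n", ")\n", "\n", "# A random input field with shape (batch, height, width, channels)\n", "field = torch.rand(1, 8, 8, 1)\n", "perturbed = field.clone()\n", "perturbed[0, 3, 3, 0] += 1.0  # change the input at a single grid point\n", "\n", "with torch.no_grad():\n", "    diff = (pointwise_net(perturbed) - pointwise_net(field)).abs()\n", "\n", "# Count the grid points whose output changed\n", "changed_points = int((diff > 1e-6).sum())\n", "print(changed_points)" ] },
{ "cell_type": "markdown", "id": "c0a1e5d3", "metadata": {}, "source": [ "Only the perturbed grid point changes its output, so `changed_points` should be `1`. A pointwise model cannot propagate information across the domain, whereas the pressure solving the Darcy equation depends on the permeability over the whole domain."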
] }, { "cell_type": "code", "execution_count": 5, "id": "e34f18b0", "metadata": { "ExecuteTime": { "end_time": "2024-09-19T13:35:31.245429Z", "start_time": "2024-09-19T13:35:29.154937Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True (mps), used: False\n", "TPU available: False, using: 0 TPU cores\n", "HPU available: False, using: 0 HPUs\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 9: 100%|██████████| 100/100 [00:00<00:00, 289.72it/s, v_num=3, data_loss_step=0.102, train_loss_step=0.102, data_loss_epoch=0.105, train_loss_epoch=0.105] " ] }, { "name": "stderr", "output_type": "stream", "text": [ "`Trainer.fit` stopped: `max_epochs=10` reached.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 9: 100%|██████████| 100/100 [00:00<00:00, 286.77it/s, v_num=3, data_loss_step=0.102, train_loss_step=0.102, data_loss_epoch=0.105, train_loss_epoch=0.105]\n" ] } ], "source": [ "# make model: a feedforward network with one input and one output channel\n", "model = FeedForward(input_dimensions=1, output_dimensions=1)\n", "\n", "# make solver\n", "solver = SupervisedSolver(problem=problem, model=model, use_lt=False)\n", "\n", "# make the trainer and train\n", "trainer = Trainer(\n", " solver=solver,\n", " max_epochs=10,\n", " accelerator=\"cpu\",\n", " enable_model_summary=False,\n", " batch_size=10,\n", " train_size=1.0,\n", " val_size=0.0,\n", " test_size=0.0,\n", ")\n", "trainer.train()" ] }, { "cell_type": "markdown", "id": "7b2c35be", "metadata": {}, "source": [ "The final loss is still high (about 0.105 after 10 epochs). We can compute the error on the training and test sets using the `LpLoss` metric."
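, "\n", "\n", "As a complementary check, the next cell sketches a relative $L^2$ error computed over each whole field and averaged over the samples, in plain PyTorch. `relative_l2_error` is a hypothetical helper, not part of PINA, and its values are not directly comparable with the `LpLoss` figures computed below." ] },
{ "cell_type": "code", "execution_count": null, "id": "d4b7f201", "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "\n", "def relative_l2_error(u_true, u_pred):\n", "    # Flatten each sample's field, then compute ||u_true - u_pred|| / ||u_true||\n", "    diff = (u_true - u_pred).flatten(start_dim=1)\n", "    ref = u_true.flatten(start_dim=1)\n", "    per_sample = torch.linalg.norm(diff, dim=1) / torch.linalg.norm(ref, dim=1)\n", "    return per_sample.mean()  # average over the samples\n", "\n", "\n", "# Toy check: a prediction that is 10% too large everywhere\n", "u_true = torch.ones(4, 8, 8, 1)\n", "u_pred = 1.1 * u_true\n", "print(float(relative_l2_error(u_true, u_pred)))" ] },
{ "cell_type": "markdown", "id": "d4b7f202", "metadata": {}, "source": [ "The toy check gives approximately `0.1`. With the tensors of this tutorial, the helper could be called as `relative_l2_error(u_test.unsqueeze(-1), solver.model(k_test.unsqueeze(-1)))`, preferably inside a `torch.no_grad()` block."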
] }, { "cell_type": "code", "execution_count": 6, "id": "0e2a6aa4", "metadata": { "ExecuteTime": { "end_time": "2024-09-19T13:35:31.295336Z", "start_time": "2024-09-19T13:35:31.256308Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Final error training 28.57%\n", "Final error testing 28.59%\n" ] } ], "source": [ "from pina.loss import LpLoss\n", "\n", "# make the metric\n", "metric_err = LpLoss(relative=False)\n", "\n", "# evaluate the trained model on the training and test sets\n", "model = solver.model\n", "err = (\n", " float(\n", " metric_err(u_train.unsqueeze(-1), model(k_train.unsqueeze(-1))).mean()\n", " )\n", " * 100\n", ")\n", "print(f\"Final error training {err:.2f}%\")\n", "\n", "err = (\n", " float(\n", " metric_err(u_test.unsqueeze(-1), model(k_test.unsqueeze(-1))).mean()\n", " )\n", " * 100\n", ")\n", "print(f\"Final error testing {err:.2f}%\")" ] }, { "cell_type": "markdown", "id": "6b5e5aa6", "metadata": {}, "source": [ "## Solving the problem with a Fourier Neural Operator (FNO)\n", "\n", "We now solve the same problem with an FNO. Since the task is to learn an operator, mapping the whole permeability field $k$ to the whole pressure field $u$, this approach is better suited, as we shall see."
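, "\n", "\n", "The core building block of an FNO is the Fourier layer: the hidden field is transformed with an FFT, only a few low-frequency modes are kept and mixed across channels by learnable complex weights, and the result is transformed back to physical space. The next cell is a minimal sketch of such a spectral convolution in plain PyTorch. `SpectralConv2dSketch` is a hypothetical name. It uses a channels-first layout and omits the pointwise linear path and the activation of a full Fourier layer, so it is not PINA's implementation. We assume that the number of retained modes plays the role of the `n_modes` argument of `FNO` used below." ] },
{ "cell_type": "code", "execution_count": null, "id": "e9c3a410", "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "\n", "class SpectralConv2dSketch(torch.nn.Module):\n", "    # Hypothetical minimal spectral convolution (not PINA's implementation)\n", "\n", "    def __init__(self, channels, modes):\n", "        super().__init__()\n", "        self.modes = modes\n", "        scale = 1.0 / (channels * channels)\n", "        shape = (channels, channels, modes, modes)\n", "        # Complex weights for the lowest positive and negative row frequencies\n", "        self.w_pos = torch.nn.Parameter(scale * torch.randn(shape, dtype=torch.cfloat))\n", "        self.w_neg = torch.nn.Parameter(scale * torch.randn(shape, dtype=torch.cfloat))\n", "\n", "    def forward(self, v):\n", "        # v has shape (batch, channels, height, width)\n", "        height, width = v.shape[-2], v.shape[-1]\n", "        v_hat = torch.fft.rfft2(v)  # shape (batch, channels, height, width // 2 + 1)\n", "        out_hat = torch.zeros_like(v_hat)\n", "        m = self.modes\n", "        # Mix the channels separately for each retained mode\n", "        out_hat[:, :, :m, :m] = torch.einsum('bixy,ioxy->boxy', v_hat[:, :, :m, :m], self.w_pos)\n", "        out_hat[:, :, -m:, :m] = torch.einsum('bixy,ioxy->boxy', v_hat[:, :, -m:, :m], self.w_neg)\n", "        # All other modes stay zero (truncation); transform back to physical space\n", "        return torch.fft.irfft2(out_hat, s=(height, width))\n", "\n", "\n", "layer = SpectralConv2dSketch(channels=4, modes=4)\n", "with torch.no_grad():\n", "    coarse = layer(torch.rand(2, 4, 16, 16))\n", "    fine = layer(torch.rand(2, 4, 32, 32))\n", "print(coarse.shape, fine.shape)" ] },
{ "cell_type": "markdown", "id": "e9c3a411", "metadata": {}, "source": [ "Because the learnable weights are attached to Fourier modes rather than to grid points, the same layer runs on a 16x16 and on a 32x32 input, and the output shapes follow the input resolution. The FNO paper relies on this property for zero-shot super-resolution."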
] }, { "cell_type": "code", "execution_count": 7, "id": "9af523a5", "metadata": { "ExecuteTime": { "end_time": "2024-09-19T13:35:44.717807Z", "start_time": "2024-09-19T13:35:31.306689Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True (mps), used: False\n", "TPU available: False, using: 0 TPU cores\n", "HPU available: False, using: 0 HPUs\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 9: 100%|██████████| 100/100 [00:02<00:00, 36.66it/s, v_num=4, data_loss_step=0.00164, train_loss_step=0.00164, data_loss_epoch=0.00229, train_loss_epoch=0.00229]" ] }, { "name": "stderr", "output_type": "stream", "text": [ "`Trainer.fit` stopped: `max_epochs=10` reached.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 9: 100%|██████████| 100/100 [00:02<00:00, 36.56it/s, v_num=4, data_loss_step=0.00164, train_loss_step=0.00164, data_loss_epoch=0.00229, train_loss_epoch=0.00229]\n" ] } ], "source": [ "# make model\n", "lifting_net = torch.nn.Linear(1, 24)  # lifts the 1 input channel to 24 hidden channels\n", "projecting_net = torch.nn.Linear(24, 1)  # projects the 24 hidden channels back to 1 output channel\n", "model = FNO(\n", " lifting_net=lifting_net,\n", " projecting_net=projecting_net,\n", " n_modes=8,\n", " dimensions=2,\n", " inner_size=24,\n", " padding=8,\n", ")\n", "\n", "# make solver\n", "solver = SupervisedSolver(problem=problem, model=model, use_lt=False)\n", "\n", "# make the trainer and train\n", "trainer = Trainer(\n", " solver=solver,\n", " max_epochs=10,\n", " accelerator=\"cpu\",\n", " enable_model_summary=False,\n", " batch_size=10,\n", " train_size=1.0,\n", " val_size=0.0,\n", " test_size=0.0,\n", ")\n", "trainer.train()" ] }, { "cell_type": "markdown", "id": "84964cb9", "metadata": {}, "source": [ "The final loss is clearly lower (about 0.0023 versus 0.105 for the `FeedForward` network). Notice that the FNO has many more parameters than the `FeedForward` network, so when many data samples are used we suggest training on a GPU or TPU to speed things up. Let's now check the errors on the training and test sets."
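, "\n", "\n", "Before computing the errors, the remark about model size can be checked by counting trainable parameters. A minimal sketch follows; `count_parameters` is a hypothetical helper, not a PINA function." ] },
{ "cell_type": "code", "execution_count": null, "id": "f2d8b630", "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "\n", "def count_parameters(module):\n", "    # Total number of trainable elements; a complex-valued weight counts as one element\n", "    return sum(p.numel() for p in module.parameters() if p.requires_grad)\n", "\n", "\n", "# A layer shaped like the lifting network above: 1 * 24 weights + 24 biases\n", "lifting_params = count_parameters(torch.nn.Linear(1, 24))\n", "print(lifting_params)  # prints 48" ] },
{ "cell_type": "markdown", "id": "f2d8b631", "metadata": {}, "source": [ "Calling `count_parameters(solver.model)` at this point counts the parameters of the trained FNO; calling it on a freshly built `FeedForward(input_dimensions=1, output_dimensions=1)` gives the corresponding figure for the first model."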
] }, { "cell_type": "code", "execution_count": 8, "id": "58e2db89", "metadata": { "ExecuteTime": { "end_time": "2024-09-19T13:35:45.259819Z", "start_time": "2024-09-19T13:35:44.729042Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Final error training 3.36%\n", "Final error testing 3.54%\n" ] } ], "source": [ "model = solver.model\n", "err = (\n", " float(\n", " metric_err(u_train.unsqueeze(-1), model(k_train.unsqueeze(-1))).mean()\n", " )\n", " * 100\n", ")\n", "print(f\"Final error training {err:.2f}%\")\n", "\n", "err = (\n", " float(\n", " metric_err(u_test.unsqueeze(-1), model(k_test.unsqueeze(-1))).mean()\n", " )\n", " * 100\n", ")\n", "print(f\"Final error testing {err:.2f}%\")" ] }, { "cell_type": "markdown", "id": "26e3a6e4", "metadata": {}, "source": [ "As we can see, the error is about eight times lower than with the `FeedForward` network (3.54% versus 28.59% on the test set)!" ] }, { "cell_type": "markdown", "id": "ba1dfa4b", "metadata": {}, "source": [ "## What's next?\n", "\n", "We have shown a very simple example of how to use the `FNO` to learn a neural operator. In **PINA** we currently implement the 1D, 2D and 3D cases. We suggest extending the tutorial to more complex problems and training for longer to see the full potential of neural operators." ] } ], "metadata": { "kernelspec": { "display_name": "pina", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 5 }