PINA/tutorials/tutorial5/tutorial.ipynb
2023-11-17 09:51:29 +01:00

{
"cells": [
{
"cell_type": "markdown",
"id": "e80567a6",
"metadata": {},
"source": [
"# Tutorial 5: Fourier Neural Operator Learning"
]
},
{
"cell_type": "markdown",
"id": "8762bbe5",
"metadata": {},
"source": [
"In this tutorial we solve the 2D Darcy flow problem presented in [Fourier Neural Operator for\n",
"Parametric Partial Differential Equations](https://openreview.net/pdf?id=c8P9NQVtmnO). First, we import the modules needed for the tutorial. `scipy` is used for input/output operations; install it with `pip install scipy`."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "5f2744dc",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/sissa/apps/intelpython/2022.0.2/intelpython/latest/lib/python3.9/site-packages/scipy/__init__.py:138: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.26.0)\n",
" warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion} is required for this version of \"\n"
]
}
],
"source": [
"\n",
"from scipy import io\n",
"import torch\n",
"from pina.model import FNO, FeedForward # let's import some models\n",
"from pina import Condition\n",
"from pina import LabelTensor\n",
"from pina.solvers import SupervisedSolver\n",
"from pina.trainer import Trainer\n",
"from pina.problem import AbstractProblem\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"id": "4cf5b181",
"metadata": {},
"source": [
"## Data Generation\n",
"\n",
"We will focus on solving a specific PDE, the **Darcy flow** equation. The Darcy PDE is a second-order elliptic PDE of the following form:\n",
"\n",
"$$\n",
"-\\nabla\\cdot(k(x, y)\\nabla u(x, y)) = f(x, y) \\quad (x, y) \\in D.\n",
"$$\n",
"\n",
"Specifically, $u$ is the flow pressure, $k$ is the permeability field and $f$ is the forcing function. The same equation models a variety of systems, including flow through porous media, elastic materials and heat conduction. Here the domain $D$ is the 2D unit square, with Dirichlet boundary conditions. The dataset is taken from the authors' original reference.\n"
]
},
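{
"cell_type": "markdown",
"id": "3f1c9a20",
"metadata": {},
"source": [
"To make the differential operator concrete, the following snippet sketches a standard five-point finite-difference approximation of $-\\nabla\\cdot(k\\nabla u)$ on a uniform grid of the unit square, with the permeability averaged onto the cell faces. It is only an illustration under our own assumptions: the function `darcy_operator` is hypothetical, it is not part of **PINA**, and it is not necessarily how the dataset was generated. As a sanity check, with $k = 1$ and $u = x(1-x)\\,y(1-y)$ the operator reduces to $-\\Delta u = 2x(1-x) + 2y(1-y)$, which the scheme reproduces up to round-off because $u$ is quadratic in each variable.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"def darcy_operator(k, u, h):\n",
"    # Five-point finite-difference approximation of -div(k grad u) at the\n",
"    # interior nodes of a uniform grid with spacing h.\n",
"    # k, u: tensors of shape (n, n); the result has shape (n - 2, n - 2).\n",
"    # The permeability on each face is the mean of the two nodes sharing it.\n",
"    k_c = k[1:-1, 1:-1]\n",
"    k_jp = 0.5 * (k_c + k[1:-1, 2:])   # face between (i, j) and (i, j+1)\n",
"    k_jm = 0.5 * (k_c + k[1:-1, :-2])  # face between (i, j) and (i, j-1)\n",
"    k_ip = 0.5 * (k_c + k[2:, 1:-1])   # face between (i, j) and (i+1, j)\n",
"    k_im = 0.5 * (k_c + k[:-2, 1:-1])  # face between (i, j) and (i-1, j)\n",
"    u_c = u[1:-1, 1:-1]\n",
"    flux = (k_jp * (u[1:-1, 2:] - u_c) - k_jm * (u_c - u[1:-1, :-2])\n",
"            + k_ip * (u[2:, 1:-1] - u_c) - k_im * (u_c - u[:-2, 1:-1]))\n",
"    return -flux / h**2\n",
"\n",
"# Sanity check with k = 1 and u = x(1 - x) y(1 - y) on a 33 x 33 grid\n",
"n = 33\n",
"h = 1.0 / (n - 1)\n",
"s = torch.linspace(0, 1, n, dtype=torch.float64)\n",
"X, Y = torch.meshgrid(s, s, indexing='ij')\n",
"u = X * (1 - X) * Y * (1 - Y)\n",
"k = torch.ones(n, n, dtype=torch.float64)\n",
"f_exact = 2 * X * (1 - X) + 2 * Y * (1 - Y)\n",
"residual = darcy_operator(k, u, h) - f_exact[1:-1, 1:-1]\n",
"print(residual.abs().max().item() < 1e-10)  # True\n",
"```\n",
"\n",
"Averaging the permeability onto the faces keeps the scheme in flux (conservative) form; other face averages, such as the harmonic mean, are also common when $k$ varies sharply."
]
},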
{
"cell_type": "code",
"execution_count": 2,
"id": "2ffb8a4c",
"metadata": {},
"outputs": [],
"source": [
"# load the dataset (Data_Darcy.mat must be in the working directory)\n",
"data = io.loadmat(\"Data_Darcy.mat\")\n",
"\n",
"# extract the permeability fields k (inputs), the solutions u (outputs)\n",
"# and the grid coordinates; a trailing channel dimension is added to k and u\n",
"k_train = torch.tensor(data['k_train'], dtype=torch.float).unsqueeze(-1)\n",
"u_train = torch.tensor(data['u_train'], dtype=torch.float).unsqueeze(-1)\n",
"k_test = torch.tensor(data['k_test'], dtype=torch.float).unsqueeze(-1)\n",
"u_test = torch.tensor(data['u_test'], dtype=torch.float).unsqueeze(-1)\n",
"x = torch.tensor(data['x'], dtype=torch.float)[0]\n",
"y = torch.tensor(data['y'], dtype=torch.float)[0]"
]
},
{
"cell_type": "markdown",
"id": "9a9defd4",
"metadata": {},
"source": [
"Let's visualize the permeability field and the corresponding solution of the first training sample."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c8501b6f",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAEjCAYAAAARyVqhAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA3QklEQVR4nO3dC3xU5Zk/8GdmkkzCJQHkEoJcBAUUuQgWGgpVCgVZF4VaiqxdwCrddXFXP3ywGj+IeKmpul5qYdF2i+haBewq7K4uW0SBsoAKSJVVKaFAEiGQBHInmdv5n+f1P2MmyWTeB+bMnJn8vp/PIczknZN3LueZ57znnOd1GCYCAAAAsDFnojsAAAAAEA0SFgAAALA9JCwAAABge0hYAAAAwPaQsAAAAIDtIWEBAAAA20PCAgAAALaHhAUAAABsDwkLAAAA2B4SFkhaK1euJIfDQRUVFVHbDho0iBYtWhS6vX37dvVY/hnEv+d2AKnu448/pokTJ1Lnzp3VdnDw4MHQ9nQhdLed48ePq7+xbt26C/o7F4L/Fv9N/tuxdP3116sF4gcJC0AEDQ0NKog3T2oAkp3X66W5c+fS2bNn6bnnnqN/+7d/o4EDBya6W7b0+eefqxgQ62QHLkzahT0MILkcPnyYnM728/Pf/OY3FAgEwhKWRx55RP0fe1KQKo4ePUonTpxQn/c777wzdP/y5cvpgQceSGDP7JmwcAzg7b/lCNIf/vCHBPWq40LCAlHxF3enTp2S+pVyu91R26Snp8ehJwCJdebMGfWzW7duYfenpaWpBfRkZGTgpYozHBJKcsHjzl9++SX96Ec/ouzsbLrkkkvonnvuocbGxrC2r732Go0bN46ysrKoR48edOutt1JJSUlYG96TuPrqq2n//v303e9+VyUqDz74YOjY8z//8z/T6tWrafDgwep306dPV+vgSb8fe+wxuvTSS9X6b775ZjXk3NJ///d/0+TJk9Wx865du9KNN95I//d//xfW5tNPP1XHxPlvZGZmUm5uLv3kJz+hysrKNl8DPocl2nNveQ5LtOPw/Hx79eql/s97WPzceeHX++WXX1b//+STT1qt44knniCXy0VfffVVu38LIBH4M37dddep//NhIf4cB0cPI53DohM32lJVVaX+Xk5OjkqOFi5cqO7TPWzF290VV1yhYgBv15MmTaKtW7eGtXv//fdD8YT/BsedL774Iur6g9tyS83jBJ/7wq8RmzJlSigGBA8Rt3UOCyeDd9xxB/Xp00f1e/To0fTKK6+EtWkeS3/961/TkCFD1A7Vt771LXVuEUSGdDpF8Bc2b2yFhYW0d+9eeuGFF+jcuXP06quvqt///Oc/p4ceeki142Hg8vJy+tWvfqWSEv7ibb63xYnBzJkzVWD68Y9/rDa+oN/97nfk8XjoH//xH1VC8tRTT6l1fu9731Mb8v33309FRUVq3cuWLaO1a9eGHsvHyjlozZgxg5588kk1crNmzRoViLgPwWSBg9Jf/vIXuv3221WywgkNb9j8k59by6Aa7blfCE5WuG933XUXzZkzh37wgx+o+0eNGkWXXXYZLVmyRL0W11xzTdjj+D4OYv369bvgvw1glb/7u79Tn01OrP/pn/5JfUk2375bksSN5ngHhpOHXbt20d///d/TlVdeSW+//bba/nVwMsHbM//N8ePHU01NDe3bt48OHDhA3//+91Wb9957T8Up3rHh9ufPn1d9+853vqPaXewJ9Pwc+TXieMI7bfwcWPBnS/z3edvn+Hf33XerOPHmm2+qBIgTNd6Rau7111+n2tpa9Z5wTONYynGGYx9GeyMwP1iQxB5++GGD38abbrop7P5/+Id/UPf/6U9/MsyM3jD3+g0z+IS1+eyzzwxzCDjsfnPvSz3uxRdfDGt77Ngxdb/5RW6YG1/o/oKCAnW/uSdhmHtFofvnz59vmEOmhjnSoW6bG6ZhBjdj8eLFYestKyszzD2wsP
vNRKbV83zjjTfU39m5c6fouQeZJxUaZrAM3f7ggw9UG/4ZxL/ndkFmcFZt+O+0xM8vLy/P8Pv9ofvMIKnamyMwrdoD2EXws29+mYbdH9yegiRxo+W2s2nTJrUu80s4dJ/P5zPM0RCtbYTjiTn62m6bMWPGGL179zbMHazQfbzNm+eqGQsWLAjdx3+L/ybHsKBI23XLOMGvUcs40TxW8hL0/PPPq7bmiFToPnPnzsjPzze6dOlimElXWCw1R40Mc6cv1Hbz5s3q/v/8z/9s93l3ZDgklCJ4j785HgFh7777Lr311lvqZFLeS+LDJ8GFRy94yNXcGMMey8OTPLrRFh4i5SHeoAkTJqifPBLT/Pg3388jMcFDIzxqwnsZ5hd9WB/48Am3bd4HHnoO4kM73O7b3/62us17TpLnbhUzINLJkyfD+s2jK9z3W265xbK/CxAv0rjRHG97HA94hDKIt/XgthkNj9zwiOqRI0fa/P2pU6fUpdg8esGHqYJ4BJRHYKzc9iPhv8mvDce4IB4p4VGauro62rFjR1j7efPmUffu3UO3+dAW4xEWaBsOCaUIDiDN8XFRviqGj5fyTzM5bdUmqOXwIw8ZRzqhbMCAAWG3g8lL//7927yfD82wYODhQ0dt4fNPgvhQEx+/Xr9+fegEwaDq6mrRc7cKB8W+ffuqJGXq1KkqsJujQGoYnM/NAUh2vM1K4kZzfBUSbx/myELY/cOGDdP6248++qjaloYOHarOqbvhhhvob//2b1VCElx/pPXxIZv/+Z//ofr6enVuS7xwn/i1ank1YvAQUrDPkWJpMHkJxkxoDQlLimp+ngd/mfJtPuGV93JaahlUmo9wtNTW49u7nwNesA/B81h4L6Sl5qMzvEe3e/duuu+++8gc9lX948dz0Gp+2XEkF1r8SoKf79/8zd+oS0P/5V/+hf73f/9XjbjwSBNAKpDGjVji80f48mvzMIm6fPhf//VfVc0Y81B12KXYsWYe4rVs3dKYCa0hYUmhvSE+ySuIT/zigMMnnvGGwRsB/573WBKBRz2YecyZpk2bFrEd711s27ZNjbCsWLEidH+koeFoz/1iREt8+LDQM888Q+YxZxXU+URdPqEYIBXwNnuhcYML0fF2zIdCmic2XA9JFx/q4UPTvPB6OInhk2s5YQkWumtrfXzFZM+ePdsdXeHRjJZXLPEhbD7UdKE7P9wnvsKRY0/zURbuT/D3cHFwDkuK4EuNm+Oz5RmfRc9nnnPSwklAy+ydb0e6XDiW+IucD/vw1Ql8yWJLfPVB872Olv00T2i7oOd+MYK1ZyJdisnD07zw3t+///u/q6uqUMcCUsXFxI2/+qu/IvMkW3WlXfPRi+C2GU3LdXPSc/nll1NTU5O6zYebePSVLxluvn0eOnRIjcjw34+WjO3cuTPsPr4SseUISzDp0bkcm/9mWVkZbdiwIXQfvwb8nLn/wcvJ4cJhhCVFmGee00033aQOm+zZs0fVTuBDFlwHgD3++ONUUFCgzuuYPXu2Os+CH8OXGv70pz9VlyBbiZMVDl58HHrs2LHqy51HJIqLi+mdd95RlyKuWrVKteM9Kb7EjxMbPp+GAxD39UKf+4XiQ2NXXXWVCkC8h8l7fHw8nZfmoyzB1w6HgyCV8Jf6hcaNWbNmqW2aK+fyY3k74pN42zoHrS3cni8R5vovvN3xJc2///3v1eXCQU8//bTaKcnPz1e1T4KXNfP5c23VWGmOR2n4cms+QZ7PR/vTn/6kznvhkZnmOCnipI3LMHDf+YIEPg+PR4pb4tfjpZdeUicCcx0rHuHlPvPhYt7hwrltMZCgq5MgRoKXIn7++efGD3/4Q8PcKAxzuNMwN2zD3IDD2pqjAMakSZMMc69BLcOHDzeWLFlimMOqoTZ8md6IESNa/Z3gpXhmkNC6RDJ4KeHHH3/cqr052qIuZc7MzDTMoGiYG7hhBqRQm9LSUmPOnDnqMmhuN3fuXMM8P6TVpYiS534hlz
Uz81wawwya6hLtln+fmUPI6tJPM6Fp9ZoBJPNlzZK40da2w5cbmzsohrkTorZj/v8nn3yidVmzmSgZ48ePVzHA3HFQf5Mvo+bLhJt77733DDMxUm3475iJkooHzbV1WTOXI7j//vsNM0ExzJFUFZPMQ8mt4gQzz1MzBg8erLbz5jGj5WXN7PTp04Z5CEutl2PGyJEjWz3XSLGUtRVj4BsO/icGeQ8kCO9J8JAtH1JpuXcA1uPLPHl4ms+34QJbAABgDZzDAnARuHw3H/fmQ10AAGAdnMMCcAF4DhOeyZVLl/Ox/Yu9IgkAANqHhAXgAnBhK64VwycW6l75AAAAFw7nsAAAAIDt4RwWAAAAsD0kLAAAAGB7KXEOC5dC5nlcuDBPPOaRAYDWuEJCbW0t5eXltZoAzq4QOwCSJ26kRMLCyUrL2YIBIDFKSkro0ksvTYqXH7EDIHnihmUJC8/vwqWTeW4FLpHOV1KMHz8+Yvs333xTFd7iMs48RTeXQo42H0RQsOTxpQ8vJ2dmZtT2jugT/obpXCzYWxQM8NRc2XpOnUjcpyJP5X6xOp3Rrx2YXierM5hVrv8cyan/4gVcspG0+lz9j3rmOf0PSNr5gOQMd+22LKPq63lTdATS2575tS1pReETvEXjr6jQaucjL+2idy+qBHk84wYL9vXEgUGU3SX6dl7ur9deN6vw68eOs4HIs6S3dM4feWK/lqoF61Xr9unPwlzp1e9HhUc2u3NFk377ivNfz/ul42ytflvmqY7+nRKUVqUfZ9yV+jEss1IYdyv0Z53OOt2g3dZ16qyoH76y0zGNG5YkLDz3ytKlS9VU4BMmTFDzKPDkdzyzZltzMPDlofPnz6fCwkL667/+a3r99ddVbYsDBw6EzdsSSfAwECcrViQsLrc1CYszS/9LxuW2LmFxqarzsW/L0tJc1iQsabKExZWh/1FPSxckLD5BwhIQvnaCpCwgeJ3TnBmifjgcmp+9///0LvSwbLzjRvO+crKS3TX6dt4oSECk7ZsCgrZ+/ffb45eF+fM+wZeuVz8uZXhkn7v0NP32aU63dluXXz8BYU6Pfntno/5r53I7rIu76foJS5pLv61LGDtIJ3YI4oYlB5qfffZZWrx4sZoWnCex4gDEM9+uXbu2zfa//OUv1cR19913H1155ZX02GOPqQnyeDI8AOgYEDcAIK4Ji8fjUTNVTps27Zs/Yp5Iw7d5Jt228P3N2zPes4rUnqcYr6mpCVsAIHnFI24wxA6A5OW0YjI4nlulT58+YffzbT4u3Ra+X9Keh4B5CvHgghNuAZJbPOIGQ+wASF7Jce1hCwUFBVRdXR1a+OxiAADEDoDUFfOTbnv27Ekul4tOnw4/O5hv5+bmtvkYvl/S3u12qwUAUkM84gZD7ABIXjEfYcnIyKBx48bRtm3bwooz8e38/Pw2H8P3N2/Ptm7dGrE9AKQWxA0ASMhlzXxp4sKFC+naa69VNRT48sT6+np11RBbsGAB9evXTx1PZvfccw9dd9119Mwzz9CNN95I69evp3379tGvf/1rK7oHADaEuAEAcU9Y5s2bR+Xl5bRixQp1AtyYMWNoy5YtoRPkiouLw0rwTpw4UdVQWL58OT344IOqANSmTZu0aykEpZ13kNOIfi13Wp2sTkTtEP1aG5n9a/VXXKlfwKgxz6e/XlOnYkHdEf26QRRIF752A/Sv29d4675pKyjvIq29U32Z/sodgo50KdWvd8CauukX5ep6VP9zZ/TrJeqHQ/MqPIdhbtP6te5sEzcA2mXVbC92mUXGISkeltjTXh0GF/JPcnxZM18tNHjFz7UKx0kTlsY+1iQsDYKEhQIOyxKWLiWGZYkCf4clW8LiydbviEOw9UgTFkkgkSQsDr9skze+OKrVzmd46YOmjepE+OzsbNHfSHTsOPfnwVqF484IK92WCwrHVQoq0p7161eBrfLLKrueFVS6rfDqty33yCoglzcK1n1eP7mvrNFvy5oklW
7PCipqSyrdlsu22U7lgkq3ZRZWuv3qpFbc2E6bteJGUl4lBAAAAB0LEhYAAACwPSQsAAAAYHtIWAAAAMD2kLAAAACA7SFhAQAAANtDwgIAAAC2h4QFAAAAbA8JCwAAANgeEhYAAADomHMJJYqnl4+cWdHn3PH0kZU5vnJI9PLCQaXVOdpts3vXabdt+rSbdltpSfy6/volop2yKY1EAoJy+wH9KYoUf6Zg+gFBvX3JFAGBNNl8Ap1O6/ejsY9++fVOR8+J+uHopveZdgY8RGdEq4Y48EuCAbcXTHLjF+zzBiRzb9hk6pyvHyD4vnBYNG2JdGjBIWkraZxYGGEBAAAA20PCAgAAALaHhAUAAABsDwkLAAAA2B4SFgAAALA9JCwAAABge0hYAAAAoOMlLIWFhfStb32LunbtSr1796bZs2fT4cOH233MunXryOFwhC2ZmZmx7hoA2BTiBgDEPWHZsWMHLVmyhPbu3Utbt24lr9dL06dPp/r6+nYfl52dTadOnQotJ06ciHXXAMCmEDcAIO6Vbrds2dJq9IRHWvbv30/f/e53Iz6OR1Vyc3Nj3R0ASAKIGwCQ8NL81dXV6mePHj3abVdXV0cDBw6kQCBAY8eOpSeeeIJGjBjRZtumpia1BNXU1Kifjgy/WqLp0/vrPulq8uu/TIO665c9P3RwkHZbRxfZdAJE+u2dHv3SzA4ry+1nBbTbGlnR3+fmXBpTNgSlpemv29OgP0dAfaZsc3N69F+8zHP674w/J0vUD1el5mc6IHtP4h032osdqUxSPp8FDHuU2w+Io40eh6TUvnqAflPDKZjWw+mw7FiIIYi7RhKV8bf0pFsOIvfeey995zvfoauvvjpiu2HDhtHatWtp8+bN9Nprr6nHTZw4kUpLSyMe787JyQkt/fv3t+opAECcWRU3GGIHQPKyNGHhc1kOHTpE69evb7ddfn4+LViwgMaMGUPXXXcdvfXWW9SrVy966aWX2mxfUFCg9sCCS0lJiRXdB4AEsCpuMMQOgORl2SGhu+++m/7rv/6Ldu7cSZdeeqnosenp6XTNNddQUVFRm793u91qAYDUYmXcYIgdAMkr5iMshmGooPP222/T+++/T5dddpl4HX6/nz777DPq27dvrLsHADaEuAEAcR9h4eHc119/XR1X5losZWVl6n4+1yQr6+uT/XgYt1+/fup4Mnv00Ufp29/+Nl1++eVUVVVFTz/9tLqs+c4774x19wDAhhA3ACDuCcuaNWvUz+uvvz7s/pdffpkWLVqk/l9cXExO5zeDO+fOnaPFixer5KZ79+40btw42r17N1111VWx7h4A2BDiBgDEPWHhod1otm/fHnb7ueeeUwsAdEyIGwAQDeYSAgAAANtDwgIAAAC2h4QFAAAAbM/y0vzxlJHlJVen6DWJ+3etEq3XKSjl/OFnl+uv2C0oRe8SlpP26ZdQDkgqtafJ+uEUlMTv0qVRu22vLu1PptmqfVadqL2usvps7bYnSnqK1u3ppl9fO5AumF7Bq/+5g/iRlMT3CGqvS8vn+y0qzW9VqX0mqy4vi2EOQewVlcSXTFsi/KYOuPRfEcMlGLdI5dL8AAAAALGAhAUAAABsDwkLAAAA2B4SFgAAALA9JCwAAABge0hYAAAAwPaQsAAAAIDtIWEBAAAA20PCAgAAALaHhAUAAABsL6VK8xsBBwXMJZrDFb1F660tzbamdH26oES6xvMKk6G/bmeGX7ttZiePqBu9uuqX0B+SXaHddniXU6J+5KXrT8dQH3Brt/3QNVi7bVnXrtptmSOQod3WLyjN76xvlPUjTS9MOAIo+d+SX1A0XtJWUsbfa8jCvFdQM94v2Of1C6cIsIq4urxTUJpfEP8l5fYDabJOG6LS/IJ1O1GaHwAAAKBdOCQEAAAAHS9hWblypTnk5ghbhg8f3u5j3nzzTdUmMzOTRo4cSe+++26suwUANoa4AQAJGWEZMWIEnTp1KrTs2rUrYtvdu3fT/P
nz6Y477qBPPvmEZs+erZZDhw5Z0TUAsCnEDQCIe8KSZp6kl5ubG1p69uwZse0vf/lLuuGGG+i+++6jK6+8kh577DEaO3YsrVq1yoquAYBNIW4AQNwTliNHjlBeXh4NHjyYbrvtNiouLo7Yds+ePTRt2rSw+2bMmKHuj6SpqYlqamrCFgBIblbHDYbYAZC8Yp6wTJgwgdatW0dbtmyhNWvW0LFjx2jy5MlUW1vbZvuysjLq06dP2H18m++PpLCwkHJyckJL//79Y/ocACC+4hE3GGIHQPKKecIyc+ZMmjt3Lo0aNUrt8fAJtFVVVbRx48aY/Y2CggKqrq4OLSUlJTFbNwDEXzziBkPsAEhelheO69atGw0dOpSKiora/D2f43L69Omw+/g23x+J2+1WCwCkJiviBkPsAEheltdhqauro6NHj1Lfvn3b/H1+fj5t27Yt7L6tW7eq+wGgY0LcAADLE5Zly5bRjh076Pjx4+qS5Tlz5pDL5VKXLrMFCxaoYdmge+65Rx23fuaZZ+jLL79U9Rj27dtHd999d6y7BgA2hbgBAHE/JFRaWqqSk8rKSurVqxdNmjSJ9u7dq/7P+Mx/p/ObPGnixIn0+uuv0/Lly+nBBx+kK664gjZt2kRXX321+G8bxtdLNLUnckTrdXr12wYyBSv2uSyZo0Jx6bdPE8wl1L3zeVE3hnULH7ZvT372Ue22YzMjX0HSll5On3bbEr/+4cbipku02zoFc5IwydQrLo/w8yEQqKrWa2fI5pmyS9ywUkDwJkrm/PGI5vuRzf8SELT3BVyWzH/0dXtr5q1xOmVzXjkE221AEHclMd0QziUUEHyzG2mC96XZNpgSCcv69evb/f327dtb3ccn2/ECAB0T4gYARIO5hAAAAMD2kLAAAACA7TkT3QEAAACAaJCwAAAAgO0hYQEAAADbQ8ICAAAAtoeEBQAAAGwPCQsAAADYHhIWAAAAsD3LZ2uOp4xPupDLHb02fnq6bL0BwcTQRoOgjLOgH4EMWel1v8OaUu2d02Xl13PdNdptB2ec0W8r/OR2cXbRbttg1Gm3dTn0y3z7/bL9gzTBlBCOgKDMd6ZsAwh49DoSMPSnP0hWfp25P5rxCPYJJeX2JWX8vYL1qvaScvuCMv7SUvtWleaXkpTmlwwBCF5mUan9r9vrv3aGS/A6OxL7nmCEBQAAAGwPCQsAAADYHhIWAAAAsD0kLAAAAGB7SFgAAADA9pCwAAAAgO0hYQEAAICOl7AMGjSIHA5Hq2XJkiVttl+3bl2rtpmZ0WupAEBqQewAgLgWjvv444/J7/eHbh86dIi+//3v09y5cyM+Jjs7mw4fPhy6zUkLAHQsiB0AENeEpVevXmG3f/GLX9CQIUPouuuui/gYTlByc3Nj3RUASCKIHQCQsNL8Ho+HXnvtNVq6dGm7oyZ1dXU0cOBACgQCNHbsWHriiSdoxIgREds3NTWpJaim5uvy71yBWqcKtbtKVl7bcDosKYksKfnvy5KNOkmqWnsz9T8Gdd4MUT8a/PrtGwQvSINRL+oHBRq1m5b7s7TbVnr0S/77PLIS6RmCSvdOr/5n2nFeNr2CM0OvlL+Ty9brv8y2ih26vhk71hMw9I+6BwRH6P2Ctl5hXXdJKX+/INBIyvhbSdoLUXtBGX/D0tL8ZFEZ/8Se9mrpX9+0aRNVVVXRokWLIrYZNmwYrV27ljZv3qwCFAeeiRMnUmlpacTHFBYWUk5OTmjp37+/Fd0HgARB7ACAuCYsv/3tb2nmzJmUl5cXsU1+fj4tWLCAxowZow4bvfXWW2po+KWXXor4mIKCAqqurg4tJSUlVnQfABIEsQMA4nZI6MSJE/Tee++pBEQiPT2drrnmGioqKorYxu12qwUAUg9iBwDEdYTl5Zdfpt69e9ONN94oehxfYfTZZ59R3759LeoZANgZYgcAxC1h4fNQOOgsXLiQ0tLCB3H48A8f0gl69NFH6Q9/+AP95S9/oQMHDtCPf/
xjtYd15513WtE1ALAxxA4AiOshIT4UVFxcTD/5yU9a/Y7vdzq/yZPOnTtHixcvprKyMurevTuNGzeOdu/eTVdddZUVXQMAG0PsAIC4JizTp08ngy9xbMP27dvDbj/33HNqAQBA7ACASDCXEAAAANgeEhYAAACwPSQsAAAAYHtIWAAAAKBjzyUUb9nH/ZSWLp3tI8ZzCelNu6L43PrrdXaRzYDhCOjnoo0u/fl+yt36c+ewLzL1J7XMSTuv3dYjmYjDlOnwarf9vKmfdtui2p7abQP1gg+HySmZ8kcyPVazq/R0ODL1ijQ6eF6ZGM0lZFeCKZsUv2AmGslnWjLfj3QOn4BgfiBfQNAPyQRnwvbCt0XE4RCsXfIUResVziXntOb7TdqPWMMICwAAANgeEhYAAACwPSQsAAAAYHtIWAAAAMD2kLAAAACA7SFhAQAAANtDwgIAAAC2h4QFAAAAbA8JCwAAANgeEhYAAACwvZQqzZ95zkdpab6o7QJpwjL3ggrKknX7M/XzxbQmWZ+dXkEuKijj32R0EvXjy0Af7bY1TZnabT/v0lfUjyyXfmn+041dtdser+ih3TatRjadgMtjWPIZJUkpbpauOXVDQLbajsAv2CcMCOqpS9r6JXXapX0W1KKXluY3hO211yttb1E/JGX8pV0wJG+5aDoBlOYHAAAAiO0hoZ07d9KsWbMoLy/PTLYctGnTprDfG4ZBK1asoL59+1JWVhZNmzaNjhw5EnW9q1evpkGDBlFmZiZNmDCBPvroI2nXAMCmEDcAIO4JS319PY0ePVolGG156qmn6IUXXqAXX3yRPvzwQ+rcuTPNmDGDGhsjT+O6YcMGWrp0KT388MN04MABtX5+zJkzZ6TdAwAbQtwAgLgnLDNnzqTHH3+c5syZ0+p3PLry/PPP0/Lly+nmm2+mUaNG0auvvkonT55sNRLT3LPPPkuLFy+m22+/na666iqV7HTq1InWrl0r7R4A2BDiBgDY6iqhY8eOUVlZmToMFJSTk6MO8ezZs6fNx3g8Htq/f3/YY5xOp7od6TFNTU1UU1MTtgBAcopX3GCIHQDJK6YJCwcd1qdP+JUhfDv4u5YqKirI7/eLHlNYWKgCWnDp379/DHoPAIkQr7jBEDsAkldS1mEpKCig6urq0FJSUpLoLgFAEkDsAEheMU1YcnNz1c/Tp0+H3c+3g79rqWfPnuRyuUSPcbvdlJ2dHbYAQHKKV9xgiB0AySumCctll12mgsW2bdtC9/H5JXy1UH5+fpuPycjIoHHjxoU9JhAIqNuRHgMAqQNxAwAsqXRbV1dHRUVFYSfMHTx4kHr06EEDBgyge++9V11FdMUVV6hA9NBDD6maLbNnzw49ZurUqeoqo7vvvlvd5kuaFy5cSNdeey2NHz9eXWnEl0HyVUMAkPwQNwAg7gnLvn37aMqUKaHbnGwwTjjWrVtHP/vZz1Sy8dOf/pSqqqpo0qRJtGXLFlUQLujo0aPqpLmgefPmUXl5uSo4xyfMjRkzRj2m5Ql10aRXN1GaRvXzQGa6aL0Ov37NccMlKGt9Xr+tq1Fa1l3SVr/csssjG5RrasjSbltSpVkC3lTaubuoH850/fcw4BG81jX6m1BmjXB6BUFpfknpbiNd9llydtGbjsERMNf7zWadNHFDQjr7gKiEvmDA2y8piS+qvS4voW/Veq0q+y8ttS8q5W9YOEeAgGiqjiTiMLh4SpLjw058tdCUax4wExZ38iQsGfptfVmyLxlvF/32Tdn6G3BTd9nG3tRd/+Pl7ebXbuvoHH3OKNslLGdk72GnU/qvXZeT+q9H1le1on44a89rtfMFmui946vUifDJcl5ZMHac+/Ngyu4afXs86q0Trb/Ep/86fOXTT8LLffrzXZ3zddZuq9p79ecLq/Xpz/9V7dFvq9bt1W9f59Hf2alrjP4d0dz58/rr9tfpf7e4BHOLZZyT7Sh2OiOIHV/px45OR8+J+uE//M3RmEh8hpe202
atuJGUVwkBAABAx4KEBQAAAGwPCQsAAADYHhIWAAAAsD0kLAAAAGB7SFgAAADA9pCwAAAAgO0hYQEAAADbQ8ICAAAAtoeEBQAAAFJvLqFU4KpuFD7Amrk1nII5XZxNwrmEmvTf2rTz+utOr5fluO4q/dfO20W/z74s2Uc3kGHNHB9OwZxN6bKq7qL5oCRzh/i7yEqTk1PvPQ/4O2Q4sT2/YD4j8fxHgnl5JOu1cn4g6VxCoom6AvptHaK2+l2w05xGsYYRFgAAALA9JCwAAABge0hYAAAAwPaQsAAAAIDtIWEBAAAA20PCAgAAAKmXsOzcuZNmzZpFeXl55HA4aNOmTaHfeb1euv/++2nkyJHUuXNn1WbBggV08uTJdte5cuVKta7my/Dhw+XPBgBsCXEDAOKesNTX19Po0aNp9erVrX7X0NBABw4coIceekj9fOutt+jw4cN00003RV3viBEj6NSpU6Fl165d0q4BgE0hbgDAxRJXepo5c6Za2pKTk0Nbt24Nu2/VqlU0fvx4Ki4upgEDBkTuSFoa5ebmSrsDAEkAcQMAbH8OS3V1tTrE061bt3bbHTlyRB1CGjx4MN12220qwYmkqamJampqwhYASB1WxA2G2AGQvCytpd3Y2KjOaZk/fz5lZ2dHbDdhwgRat24dDRs2TB0OeuSRR2jy5Ml06NAh6tq1a6v2hYWFqk1LzgYPOXXK6JuBUERaFlmTw+PTb9soK83vPK+/7rR6QWn+WtlHxp+pv25flqAkuFv2HvozBCXEZS+1Nqf+W6KkN+h/8Lyd9V+7QJqsNH+GU++18/kCto4b7cUOuDCScvsBclhWmt8fEPRDUBJftfdLSvOTJeX2nX79tmrdfmum9SDNWJB0Iyx8Au6PfvQjMgyD1qxZE3W4eO7cuTRq1CiaMWMGvfvuu1RVVUUbN25ss31BQYHaAwsuJSUlVjwFAIgzK+MGQ+wASF5pVgadEydO0Pvvv9/uXlJbeBh46NChVFRU1Obv3W63WgAgdVgdNxhiB0DycloVdPjY8nvvvUeXXHKJeB11dXV09OhR6tu3b6y7BwA2hLgBADFPWDiZOHjwoFrYsWPH1P/5ZDcOOj/84Q9p37599Lvf/Y78fj+VlZWpxePxhNYxdepUdfVQ0LJly2jHjh10/Phx2r17N82ZM4dcLpc6hg0AyQ9xAwDifkiIk5EpU6aEbi9dulT9XLhwoSoA9x//8R/q9pgxY8Ie98EHH9D111+v/s+jJxUVFaHflZaWquSksrKSevXqRZMmTaK9e/eq/wNA8kPcAIC4JyycdPAJcZG097sgHklpbv369dJuAEASQdwAgIuFuYQAAADA9pCwAAAAgO0hYQEAAADbQ8ICAAAAHbs0f7w5PF5yOKPnYEaWsOicuV79TghKF2ucoBzkbBL0gVctKaFcr5+3utJkdeuNdP32krYBt+yjG8gQlP0XtDXSrCtV7e2s/3r40/XXm9YoLKGv+5mWTnkBHYqk1D4zRKX5BW39sv10Q1Ca3+ETtJV8rfismwbEERDU5vdbNE+NJoywAAAAgO0hYQEAAADbQ8ICAAAAtoeEBQAAAGwPCQsAAADYHhIWAAAAsD0kLAAAAGB7SFgAAADA9pCwAAAAgO0hYQEAAADbS6nS/I0DelBaWmbUdu6vqmUrztCve+443yRbty6/X9TccV5Wyt8qOlMlhAjK/jvT0yx7D40M/XUHMvT77OkunBJCMHWDu1pQXltIt3S3qMQ3dLhy+9LS/L6AfuyQtA0IS/OTV7+906v/HGVttZt+3d6vvy06fILtNoDS/AAAAADtJ2Lt/hYAAAAgGROWnTt30qxZsygvL48cDgdt2rQp7PeLFi1S9zdfbrjhhqjrXb16NQ0aNIgyMzNpwoQJ9NFHH0m7BgA2hbgBAHFPWOrr62n06NEqwYiEE5RTp06FljfeeKPddW7YsIGWLl1KDz/8MB04cECtf8aMGXTmzBlp9wDAhhA3ACDuJ9
3OnDlTLe1xu92Um5urvc5nn32WFi9eTLfffru6/eKLL9I777xDa9eupQceeEDaRQCwGcQNALDlOSzbt2+n3r1707Bhw+iuu+6iysrKiG09Hg/t37+fpk2b9k2nzCtL+PaePXvafExTUxPV1NSELQCQ3KyOGwyxAyB5xTxh4cNBr776Km3bto2efPJJ2rFjh9q78ke4LLeiokL9rk+fPmH38+2ysrI2H1NYWEg5OTmhpX///rF+GgAQR/GIGwyxAyB5xbwOy6233hr6/8iRI2nUqFE0ZMgQtfc0derUmPyNgoICdc5LEI+wIGkBSF7xiBsMsQMgeVl+WfPgwYOpZ8+eVFRU1Obv+Xcul4tOnz4ddj/fjnQeDJ8jk52dHbYAQOqwIm4wxA6A5GV5wlJaWqqORfft27fN32dkZNC4cePUUHBQIBBQt/Pz863uHgDYEOIGAFx0wlJXV0cHDx5UCzt27Jj6f3FxsfrdfffdR3v37qXjx4+rpOPmm2+myy+/XF2mHMRDvKtWrQrd5sM7v/nNb+iVV16hL774Qp1wx5dBBq8aAoDkhrgBAHE/h2Xfvn00ZcqU0O3guSQLFy6kNWvW0KeffqoSj6qqKlVcbvr06fTYY4+podigo0ePqpPmgubNm0fl5eW0YsUKdcLcmDFjaMuWLa1OqIsmo7KB0lzR59zxd+8kWq/rXIN2WyMzQ7uto/68fif8sjkcDJ9g7iGvR7+tlfPFOAVzjZiHAyQc5kieflv9eYeoWxftpi6PYL2mrHLBey546XyZsteurn+m3nq9yRk3rNzDczr030MXSdrqb4cuQR+sJJ1LyC+Zp0g0l5CsHyRo77BofiCnZL4f1d78R5PTq//5cEi+VyzgMEwJ7UEM8Em3fLXQ966+z0xYok8wF+ik/+UlTVjI5bRHwtIkSEKQsIQTJCwBQcLi7ZGl3VaRbJkWJiyeHL32Pm8j7f/9cqqurk6a88qCsePcnwdTdtfo2+4xb51o/SV+/c/HSW937bblPv3Xt8Kn3wdW6dFvf9ajv/N3tkm2o1jdpJcos9rz+m3P18vif6BOPx6kVetvWxnV+hutu0r2NZ1Vqd++0yn9CXszSiKXGmiL73hx9DaGl7bTZq24gbmEAAAAwPaQsAAAAIDtIWEBAAAA20PCAgAAALaHhAUAAABsDwkLAAAA2B4SFgAAALA9JCwAAABge0hYAAAAIPVK89uZke4iQ6Nku8MnqxorKeXvOluv3daQlICvkVXYpIB+CWXDq1/H2fD5hP2wppCyQ1BRmBmN+tUcndld9fvh0X89Ms7ofzZYIEN/8/Rcol/pU8qbpVeR0+8SljxPQlY+RUkZf7sICEosG9LS/IJy+36/oDS/TxY7HD5BuX1B5XpRW2HYdQpK+Yu+D4UV12MNIywAAABge0hYAAAAwPaQsAAAAIDtIWEBAAAA20PCAgAAALaHhAUAAABsDwkLAAAApF7CsnPnTpo1axbl5eWRw+GgTZs2hf2e72trefrppyOuc+XKla3aDx8+XP5sAMCWEDcAIO4JS319PY0ePZpWr17d5u9PnToVtqxdu1YlILfccku76x0xYkTY43bt2iXtGgDYFOIGAMS90u3MmTPVEklubm7Y7c2bN9OUKVNo8ODB7XckLa3VYwEgNSBuAICtS/OfPn2a3nnnHXrllVeitj1y5Ig6zJSZmUn5+flUWFhIAwYMaLNtU1OTWoJqamrUz8ZeWZSWHr1Eufusfpl25mzw6DdOiz41wIVwZGSI2hvnG/XXrTGdQWi9zV53rfaS0vyC6QSIZK+HU/D6GV6v/orLz2o3dbjd+utleZdYUoq7ZqBss08/r7duh9ewddxoL3ZYxUWBpCvjLym3HxCU25esl/kD+u39fsG6JW2FpfkdghBmVVvmFGyLTklpfsOaqVZscdItB5yuXbvSD37wg3bbTZgwgdatW0dbtmyhNWvW0LFjx2jy5MlUW1vbZnsOSjk5OaGlf//+VnQfAB
LgFYviBkPsAEheliYsfP7KbbfdpvZ+og0Xz507l0aNGkUzZsygd999l6qqqmjjxo1tti8oKKDq6urQUlJSYkX3ASABrIobDLEDIHlZdkjoj3/8Ix0+fJg2bNggfmy3bt1o6NChVFRU1Obv3ebQOi8AkFqsjBsMsQMgeVk2wvLb3/6Wxo0bp64okqqrq6OjR49S3759LegZANgV4gYAxCxh4WTi4MGDamF83Jj/X1xcHHYi25tvvkl33nlnm+uYOnUqrVq1KnR72bJltGPHDjp+/Djt3r2b5syZQy7zRND58+dLuwcANoS4AQBxPyS0b98+dZly0NKlS9XPhQsXqhPg2Pr1682TiY2ICQePnlRUVIRul5aWqraVlZXUq1cvmjRpEu3du1f9HwCSH+IGAFwsh5lYJPY6pRjgER2+Wih/+iMJv6zZIblETMBR2yBqb1Rbc7lmoKHBFpc1O9KFlzVnRf9chGSki9Zt1WXNfsFlzb7O+n2uGuK25LJmv7eR9m9crk6Ez87OFv2NRMeOc38eTNldow84F/vqROsv8XXSbvuVr7t223Kf/ut7xit7L8o9XbXbVjR11m57rkn/tWBnG7K029bW62/f3jph7KjV369Pr9W/BDqjWr+t+5zsa7pTuX4szSo7r93WdUq/jAPzlX4VvY3hpe20WStuYC4hAAAAsD0kLAAAAGB7SFgAAADA9pCwAAAAQMeeSyjenN4AOY3oJ716u8hOqnRm6b9MGWWRy4K35GgUzFEkmd+G1909R7tt4HS5dlvDL5vUwik40dTw689p5HDJcm1/Xb2ovS5XF/2TDo0e3UTrdnj1X2u/4HV2+kTdMDse43bQJpdFL6BTuF5peyvmHWKGpL1VbVV7sqgf+k3FU0cZwvZJAiMsAAAAYHtIWAAAAMD2kLAAAACA7SFhAQAAANtDwgIAAAC2h4QFAAAAbA8JCwAAANgeEhYAAACwPSQsAAAAYHspUenWML4u6+fzNVmyfqdfv2yg06/fB0dAUOlW0la11+9zwNBfd8CQVdx1Cio/Bt9HHQ5htcqAIS3vqscQvHYk+GyodQuKCvt8+vsefo+sWjF59d4Xv7dR/D4mWrCvNXV6pURrfbKSo/WC9g0+/fflvF//89wkrJLt8ei393r0P/++RtnXjf+8/mc60KC/3sB5YdnYRv0K3P5GpyXhwO+RbVM+QZVsn//r7VaHEZDFMJ/G94WPvNpxIyUSltrar8vhf/jHXyS4J9CK/rYgI4vB1qmxqG2Sb485OfpTQ9ghdgwcezzBPQHo2Go14obDSKbdoQgCgQCdPHmSunbtSg7HN3veNTU11L9/fyopKaHs7OwE9tAaqf78WKo/x1R6fhxKOOjk5eWR05kcR5sRO5L/c5fq21WqP0dDEDdSYoSFn+Sll14a8ff8ZibzGxpNqj8/lurPMVWeX7KMrAQhdqTG5y7Vt6tUf445mnEjOXaDAAAAoENDwgIAAAC2l9IJi9vtpocfflj9TEWp/vxYqj/HVH9+ySrV3xc8v+TnTvHPaMqedAsAAACpLaVHWAAAACA1OBPdAQAAAIBokLAAAACA7SFhAQAAANtDwgIAAAC2l9IJy+rVq2nQoEGUmZlJEyZMoI8++ijRXYqJlStXqikImi/Dhw9PdLcuys6dO2nWrFmqPDM/n02bNoX9ni9mW7FiBfXt25eysrJo2rRpdOTIkQT1NvbPb9GiRa3e0xtuuCFBve3YUjVupGLsQNxY1KHiRsomLBs2bKClS5eq69QPHDhAo0ePphkzZtCZM2cS3bWYGDFiBJ06dSq07Nq1K9Fduij19fXqPeIvi7Y89dRT9MILL9CLL75IH374IXXu3Fm9n42NVs2uGN/nxzjQNH9P33jjjTj2EDpC3Ei12IG4QR0rbnAdllQ0fvx4Y8mSJaHbfr/fMPdujcLCwgT2KjbMYGqYgTTR3bAMfyzffvvt0O1AIGDk5uYaTz/9dOi+qqoqw+12G+bGmYguxv
T5sYULFxo333xzgnoEHSFupHrsQNxIfSk5wuLxeGj//v3qsEHzSc749p49exLYs9jhwyF8eGHw4MF02223UXFxcaK7ZJljx45RWVlZ2PvJk2XxcH2qvJ9s+/bt1Lt3bxo2bBjdddddVFlZmegudSgdIW50pNiBuJF6UjJhqaioIHPPiPr06RN2P9/mL75kx1/U69atoy1bttCaNWvUhjl58mQ1RXcqCr5nqfp+Bod1X331Vdq2bRs9+eSTtGPHDpo5c6b6HEN8pHrc6GixA3Ej9aQlugMgx19kQaNGjVJBaODAgbRx40a644478JImoVtvvTX0/5EjR6r3dciQIWrUZerUqQnsGaQSxI7UcmsHixspOcLSs2dPcrlcdPr06bD7+bZ5LkSCemWdbt260dChQ6moqCjRXbFE8D3rKO8n4+F6/hyn6ntqRx0tbqR67EDcSD0pmbBkZGTQuHHj1PB6kHniprqdn5+fwJ5Zo66ujo4ePaou+U1Fl112mQo+zd/PmpoadbVQKr6frLS0VJ3DkqrvqR11tLiR6rEDcSP1pOwhIb40ceHChXTttdfS+PHj6fnnn1eXwN1+++2J7tpFW7ZsmarpwYeBTp48qS7B5D3D+fPnJ7prFxU4m+/l8bH1gwcPUo8ePWjAgAF077330uOPP05XXHGFCkQPPfSQOnFw9uzZCex1bJ4fL4888gjdcsstKjHjL5Cf/exndPnll6tLaiF+UjlupGLsQNx4pGPFjURfpmSlX/3qV4b5ZWeYe07qcsW9e/cmuksxMW/ePMPcI1LPq1+/fuq2+WWY6G5dlA8++EBdlthy4ct9g5c2m0mKYZ4AqS5nNo/PGocPH05wr2Pz/BoaGozp06cbvXr1MtLT0w3zy8RYvHixYZ40mOhud0ipGjdSMXYgbkzvUHHDwf8kKlkCAAAA6LDnsAAAAEBqQcICAAAAtoeEBQAAAGwPCQsAAADYHhIWAAAAsD0kLAAAAGB7SFgAAADA9pCwAAAAgO0hYQEAAADbQ8ICAAAAtoeEBQAAAGzv/wFc4OkI3+dhfwAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.subplot(1, 2, 1)\n",
"plt.title('permeability')\n",
"plt.imshow(k_train.squeeze(-1)[0])\n",
"plt.subplot(1, 2, 2)\n",
"plt.title('field solution')\n",
"plt.imshow(u_train.squeeze(-1)[0])\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "89a77ff1",
"metadata": {},
"source": [
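"We now define the problem. It is a very simple class inheriting from `AbstractProblem`, with a single `data` condition that pairs the permeability fields (inputs) with the corresponding solutions (outputs) for supervised training."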
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "8b27d283",
"metadata": {},
"outputs": [],
"source": [
"class NeuralOperatorSolver(AbstractProblem):\n",
"    input_variables = ['u_0']\n",
"    output_variables = ['u']\n",
"    # supervised condition: permeability fields as inputs, solutions as outputs\n",
"    conditions = {'data': Condition(input_points=LabelTensor(k_train, input_variables),\n",
"                                    output_points=LabelTensor(u_train, output_variables))}\n",
"\n",
"# make problem\n",
"problem = NeuralOperatorSolver()"
]
},
{
"cell_type": "markdown",
"id": "1096cc20",
"metadata": {},
"source": [
"## Solving the problem with a FeedForward Neural Network\n",
"\n",
"We first solve the problem with a simple feedforward neural network. Since we train on pairs of input and output fields (supervised learning), we use the `SupervisedSolver`."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e34f18b0",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/u/n/ndemo/.local/lib/python3.9/site-packages/torch/cuda/__init__.py:611: UserWarning: Can't initialize NVML\n",
" warnings.warn(\"Can't initialize NVML\")\n",
"GPU available: True (cuda), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"HPU available: False, using: 0 HPUs\n",
"Missing logger folder: /u/n/ndemo/PINA/tutorials/tutorial5/lightning_logs\n",
"2023-10-17 10:41:03.316644: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
"2023-10-17 10:41:03.333768: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n",
"2023-10-17 10:41:03.383188: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"2023-10-17 10:41:07.712785: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
"\n",
" | Name | Type | Params\n",
"----------------------------------------\n",
"0 | _loss | MSELoss | 0 \n",
"1 | _neural_net | Network | 481 \n",
"----------------------------------------\n",
"481 Trainable params\n",
"0 Non-trainable params\n",
"481 Total params\n",
"0.002 Total estimated model params size (MB)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "eb573678e5d94f0490ce09817a06f5cb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/u/n/ndemo/.local/lib/python3.9/site-packages/torch/_tensor.py:1386: UserWarning: The use of `x.T` on tensors of dimension other than 2 to reverse their shape is deprecated and it will throw an error in a future release. Consider `x.mT` to transpose batches of matrices or `x.permute(*torch.arange(x.ndim - 1, -1, -1))` to reverse the dimensions of a tensor. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3614.)\n",
" ret = func(*args, **kwargs)\n",
"`Trainer.fit` stopped: `max_epochs=100` reached.\n"
]
}
],
"source": [
"# make model\n",
"model = FeedForward(input_dimensions=1, output_dimensions=1)\n",
"\n",
"# make solver\n",
"solver = SupervisedSolver(problem=problem, model=model)\n",
"\n",
"# make the trainer and train\n",
"trainer = Trainer(solver=solver, max_epochs=100)\n",
"trainer.train()\n"
]
},
{
"cell_type": "markdown",
"id": "7b2c35be",
"metadata": {},
"source": [
"The final loss is still quite high. To quantify the error, we compute the relative error (`LpLoss` with `relative=True`) on both the training and the test set."
]
},
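{
"cell_type": "markdown",
"id": "7d4e2b61",
"metadata": {},
"source": [
"For reference, the relative error computed in the next cell can be approximated in a few lines of plain PyTorch. The following is a minimal sketch, not the source of PINA's `LpLoss`: the helper `relative_l2_error` is hypothetical and computes, for each sample, $\\|u - \\hat{u}\\|_2 / \\|u\\|_2$ over all grid points, then averages over the batch. `LpLoss` may instead take the norm along the last tensor dimension only, so the two numbers need not match exactly.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"def relative_l2_error(u_true, u_pred):\n",
"    # Per-sample relative L2 error ||u_true - u_pred|| / ||u_true||,\n",
"    # taken over all spatial points of each sample, then averaged.\n",
"    diff = (u_true - u_pred).flatten(start_dim=1)\n",
"    ref = u_true.flatten(start_dim=1)\n",
"    per_sample = torch.linalg.norm(diff, dim=1) / torch.linalg.norm(ref, dim=1)\n",
"    return per_sample.mean()\n",
"\n",
"# Synthetic check: 4 samples on a 16 x 16 grid, every value off by 10%\n",
"u_true = torch.ones(4, 16, 16)\n",
"u_pred = 0.9 * u_true\n",
"err = relative_l2_error(u_true, u_pred)\n",
"print(f'{float(err) * 100:.2f}%')  # 10.00%\n",
"```"
]
},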
{
"cell_type": "code",
"execution_count": 6,
"id": "0e2a6aa4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Final error training 56.86%\n",
"Final error testing 56.82%\n"
]
}
],
"source": [
"from pina.loss import LpLoss\n",
"\n",
"# make the metric\n",
"metric_err = LpLoss(relative=True)\n",
"\n",
"\n",
"err = float(metric_err(u_train.squeeze(-1), solver.models[0](k_train).squeeze(-1)).mean())*100\n",
"print(f'Final error training {err:.2f}%')\n",
"\n",
"err = float(metric_err(u_test.squeeze(-1), solver.models[0](k_test).squeeze(-1)).mean())*100\n",
"print(f'Final error testing {err:.2f}%')"
]
},
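{
"cell_type": "markdown",
"id": "b85a0c3e",
"metadata": {},
"source": [
"Why is the feedforward error so high? A plausible explanation is locality: a feedforward network with `input_dimensions=1` is applied independently at every grid point, so the predicted pressure at a point depends only on the permeability at that same point, whereas the solution of an elliptic PDE such as Darcy flow depends on the permeability over the whole domain. The sketch below illustrates this with a plain `torch.nn.Sequential` MLP standing in for PINA's `FeedForward`; we assume only that its layers act on the last tensor dimension, as `torch.nn.Linear` does.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"# A small MLP mapping one input channel to one output channel\n",
"mlp = torch.nn.Sequential(\n",
"    torch.nn.Linear(1, 20), torch.nn.Tanh(),\n",
"    torch.nn.Linear(20, 20), torch.nn.Tanh(),\n",
"    torch.nn.Linear(20, 1),\n",
")\n",
"\n",
"# Random fields with shape (batch, nx, ny, channels), like k_train\n",
"k = torch.rand(2, 8, 8, 1)\n",
"k_perturbed = k.clone()\n",
"k_perturbed[0, 3, 3, 0] += 1.0  # change the input at a single grid point\n",
"\n",
"with torch.no_grad():\n",
"    changed = (mlp(k_perturbed) - mlp(k)).abs() > 1e-6\n",
"\n",
"# Linear layers act on the last dimension only, so a single output value changes\n",
"print(int(changed.sum()))  # 1\n",
"print(bool(changed[0, 3, 3, 0]))  # True\n",
"```\n",
"\n",
"An operator-learning architecture such as the FNO, used next, avoids this limitation by mixing information across the whole grid."
]
},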
{
"cell_type": "markdown",
"id": "6b5e5aa6",
"metadata": {},
"source": [
"## Solving the problem with a Fourier Neural Operator (FNO)\n",
"\n",
"We now solve the same problem with an FNO. Since we are learning an operator, i.e. a map from the permeability field to the solution field, this architecture is better suited, as we shall see."
]
},
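{
"cell_type": "markdown",
"id": "c2e9f471",
"metadata": {},
"source": [
"The core of an FNO is the Fourier layer: the hidden field is transformed to Fourier space with an FFT, only a fixed number of low-frequency modes are kept and multiplied by learned complex weights, and the result is transformed back. Since the weights act on global modes rather than on grid points, every output value depends on the whole input field. The following is a minimal sketch of this spectral convolution in plain PyTorch under simplifying assumptions: the class `SpectralConv2dSketch` is hypothetical and is not PINA's implementation, it keeps only the lowest modes in the non-negative frequency corner (the original FNO also keeps the negative frequencies along the first axis), and it omits the pointwise linear term and nonlinearity that a full Fourier layer adds.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"class SpectralConv2dSketch(torch.nn.Module):\n",
"    # Simplified 2D spectral convolution:\n",
"    # FFT -> keep the lowest modes -> complex channel mixing -> inverse FFT.\n",
"    def __init__(self, channels, modes):\n",
"        super().__init__()\n",
"        self.modes = modes\n",
"        scale = 1.0 / (channels * channels)\n",
"        self.weights = torch.nn.Parameter(\n",
"            scale * torch.randn(channels, channels, modes, modes, dtype=torch.cfloat))\n",
"\n",
"    def forward(self, x):\n",
"        # x has shape (batch, channels, nx, ny); assumes modes <= nx and modes <= ny // 2 + 1\n",
"        nx, ny = x.shape[-2], x.shape[-1]\n",
"        x_hat = torch.fft.rfft2(x)  # shape (batch, channels, nx, ny // 2 + 1)\n",
"        out_hat = torch.zeros_like(x_hat)\n",
"        m = self.modes\n",
"        # out[b, o, p, q] = sum_i x_hat[b, i, p, q] * W[i, o, p, q] for the kept modes\n",
"        out_hat[:, :, :m, :m] = torch.einsum(\n",
"            'bipq,iopq->bopq', x_hat[:, :, :m, :m], self.weights)\n",
"        return torch.fft.irfft2(out_hat, s=(nx, ny))  # real, (batch, channels, nx, ny)\n",
"\n",
"# 24 channels and 16 modes mirror inner_size=24 and n_modes=16 in the next cell\n",
"layer = SpectralConv2dSketch(channels=24, modes=16)\n",
"h = torch.randn(2, 24, 64, 64)\n",
"print(layer(h).shape)  # torch.Size([2, 24, 64, 64])\n",
"```"
]
},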
{
"cell_type": "code",
"execution_count": 7,
"id": "9af523a5",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: True (cuda), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"HPU available: False, using: 0 HPUs\n",
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
"\n",
" | Name | Type | Params\n",
"----------------------------------------\n",
"0 | _loss | MSELoss | 0 \n",
"1 | _neural_net | Network | 591 K \n",
"----------------------------------------\n",
"591 K Trainable params\n",
"0 Non-trainable params\n",
"591 K Total params\n",
"2.364 Total estimated model params size (MB)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0f7225d39f7241e692c6027c72adfd5f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"`Trainer.fit` stopped: `max_epochs=20` reached.\n"
]
}
],
"source": [
"# make model\n",
"lifting_net = torch.nn.Linear(1, 24)\n",
"projecting_net = torch.nn.Linear(24, 1)\n",
"model = FNO(lifting_net=lifting_net,\n",
" projecting_net=projecting_net,\n",
" n_modes=16,\n",
" dimensions=2,\n",
" inner_size=24,\n",
" padding=11)\n",
"\n",
"\n",
"# make solver\n",
"solver = SupervisedSolver(problem=problem, model=model)\n",
"\n",
"# make the trainer and train\n",
"trainer = Trainer(solver=solver, max_epochs=20)\n",
"trainer.train()\n"
]
},
{
"cell_type": "markdown",
"id": "84964cb9",
"metadata": {},
"source": [
"With only a fifth of the epochs (20 instead of 100), the FNO reaches a lower training loss. Let's check the error on the training and test sets. Notice that the FNO has far more parameters than the `FeedForward` network (about 591K versus 481), so we suggest using a GPU or TPU to speed up training."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "58e2db89",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Final error training 26.19%\n",
"Final error testing 25.89%\n"
]
}
],
"source": [
"err = float(metric_err(u_train.squeeze(-1), solver.models[0](k_train).squeeze(-1)).mean())*100\n",
"print(f'Final error training {err:.2f}%')\n",
"\n",
"err = float(metric_err(u_test.squeeze(-1), solver.models[0](k_test).squeeze(-1)).mean())*100\n",
"print(f'Final error testing {err:.2f}%')"
]
},
{
"cell_type": "markdown",
"id": "26e3a6e4",
"metadata": {},
"source": [
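"As we can see, the relative error is much lower: about 26% for the FNO versus about 57% for the feedforward network, on both the training and the test set."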
]
},
{
"cell_type": "markdown",
"id": "ba1dfa4b",
"metadata": {},
"source": [
"## What's next?\n",
"\n",
"We have shown a very simple example of how to use the `FNO` to learn a neural operator. **PINA** currently implements the 1D, 2D and 3D cases. We suggest extending the tutorial to more complex problems and training for longer to see the full potential of neural operators."
]
}
],
"metadata": {
"interpreter": {
"hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}