{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "6f71ca5c", "metadata": {}, "source": [ "# Tutorial: Physics Informed Neural Networks on PINA\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mathLab/PINA/blob/master/tutorials/tutorial1/tutorial.ipynb)\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "ef4949c9", "metadata": {}, "source": [ "In this tutorial, we will demonstrate a typical use case of **PINA** on a toy problem, following the standard API procedure. \n", "\n", "

\n", "\n", "Specifically, the tutorial aims to introduce the following topics:\n", "\n", "* Explaining how to build **PINA** Problems\n", "* Showing how to generate data for `PINN` training\n", "\n", "These are the two main steps needed **before** starting the modelling and optimization (choosing a model and solver, and training). We will show each step in detail, and at the end, we will solve a simple Ordinary Differential Equation (ODE) problem using the `PINN` solver." ] }, { "attachments": {}, "cell_type": "markdown", "id": "cf9c96e3", "metadata": {}, "source": [ "## Build a PINA problem" ] }, { "attachments": {}, "cell_type": "markdown", "id": "8a819659", "metadata": {}, "source": [ "A problem is defined in the **PINA** framework by building a Python `class`, which inherits from one or more problem classes (`SpatialProblem`, `TimeDependentProblem`, `ParametricProblem`, ...) depending on the nature of the problem. Below is an example:\n", "\n", "### Simple Ordinary Differential Equation\n", "Consider the following:\n", "\n", "$$\n", "\\begin{equation}\n", "\\begin{cases}\n", "\\frac{d}{dx}u(x) &= u(x) \\quad x\\in(0,1)\\\\\n", "u(x=0) &= 1 \\\\\n", "\\end{cases}\n", "\\end{equation}\n", "$$\n", "\n", "with the analytical solution $u(x) = e^x$. In this case, our ODE depends only on the spatial variable $x\\in(0,1)$, meaning that our `Problem` class inherits from the `SpatialProblem` class:\n", "\n", "```python\n", "from pina.problem import SpatialProblem\n", "from pina.domain import CartesianDomain\n", "\n", "class SimpleODE(SpatialProblem):\n", "\n", "    output_variables = ['u']\n", "    spatial_domain = CartesianDomain({'x': [0, 1]})\n", "\n", "    # other stuff ...\n", "```\n", "\n", "Notice that we define `output_variables` as a list of symbols, indicating the output variables of our equation (in this case only $u$). This is done because in **PINA** the `torch.Tensor`s are labelled, allowing the user maximal flexibility for the manipulation of the tensor. 
The `spatial_domain` variable specifies the domain from which the points are sampled, in this case $x\\in[0,1]$.\n", "\n", "What if our equation is also time-dependent? In this case, our `class` will inherit from both `SpatialProblem` and `TimeDependentProblem`:\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "2373a925", "metadata": {}, "outputs": [], "source": [ "## routine needed to run the notebook on Google Colab\n", "try:\n", "    import google.colab\n", "    IN_COLAB = True\n", "except ImportError:\n", "    IN_COLAB = False\n", "if IN_COLAB:\n", "    !pip install \"pina-mathlab\"\n", "\n", "from pina.problem import SpatialProblem, TimeDependentProblem\n", "from pina.domain import CartesianDomain\n", "\n", "class TimeSpaceODE(SpatialProblem, TimeDependentProblem):\n", "\n", "    output_variables = ['u']\n", "    spatial_domain = CartesianDomain({'x': [0, 1]})\n", "    temporal_domain = CartesianDomain({'t': [0, 1]})\n", "\n", "    # other stuff ..." ] }, { "attachments": {}, "cell_type": "markdown", "id": "ad8566b8", "metadata": {}, "source": [ "where we have added the `temporal_domain` variable, which specifies the time domain over which the solution is sought.\n", "\n", "In summary, using **PINA**, we can initialize a problem with a class which inherits from different base classes: `SpatialProblem`, `TimeDependentProblem`, `ParametricProblem`, and so on, depending on the type of problem we are considering. 
Here are some examples (more in the official documentation):\n", "* ``SpatialProblem`` $\\rightarrow$ a differential equation with spatial variable(s) ``spatial_domain``\n", "* ``TimeDependentProblem`` $\\rightarrow$ a time-dependent differential equation with temporal variable(s) ``temporal_domain``\n", "* ``ParametricProblem`` $\\rightarrow$ a parametrized differential equation with parametric variable(s) ``parameter_domain``\n", "* ``AbstractProblem`` $\\rightarrow$ any **PINA** problem inherits from here" ] }, { "attachments": {}, "cell_type": "markdown", "id": "592a4c43", "metadata": {}, "source": [ "### Write the problem class\n", "\n", "Once the `Problem` class is initialized, we need to represent the differential equation in **PINA**. To do this, we load the **PINA** operators from the `pina.operator` module. Again, we'll consider Equation (1) and represent it in **PINA**:" ] }, { "cell_type": "code", "execution_count": null, "id": "f2608e2e", "metadata": {}, "outputs": [], "source": [ "from pina.problem import SpatialProblem\n", "from pina.operator import grad\n", "from pina import Condition\n", "from pina.domain import CartesianDomain\n", "from pina.equation import Equation, FixedValue\n", "\n", "import torch\n", "import matplotlib.pyplot as plt\n", "plt.style.use('tableau-colorblind10')\n", "\n", "class SimpleODE(SpatialProblem):\n", "\n", "    output_variables = ['u']\n", "    spatial_domain = CartesianDomain({'x': [0, 1]})\n", "\n", "    domains = {\n", "        'x0': CartesianDomain({'x': 0.}),\n", "        'D': CartesianDomain({'x': [0, 1]})\n", "    }\n", "\n", "    # defining the ODE residual\n", "    def ode_equation(input_, output_):\n", "\n", "        # computing the derivative\n", "        u_x = grad(output_, input_, components=['u'], d=['x'])\n", "\n", "        # extracting the output variable u\n", "        u = output_.extract(['u'])\n", "\n", "        # calculate the residual and return it\n", "        return u_x - u\n", "\n", "    # conditions to hold\n", "    conditions = {\n", "        'bound_cond': Condition(domain='x0', equation=FixedValue(1.)),\n", "        'phys_cond': Condition(domain='D', equation=Equation(ode_equation))\n", "    }\n", "\n", "    # defining the true solution\n", "    def truth_solution(self, pts):\n", "        return torch.exp(pts.extract(['x']))\n", "\n", "problem = SimpleODE()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "7cf64d01", "metadata": {}, "source": [ "Inside the `Problem` class, we write one function per equation, each returning a residual. These residuals, together with the one for the initial condition, are what the PINN optimization minimizes. For example, in the domain $[0,1]$, the ODE (`ode_equation`) must be satisfied. We represent this by returning the derivative of `u` minus `u` itself (the residual), which training drives towards 0. The same is done for every condition. Notice that we do not pass a Python function directly, but an `Equation` object, which is initialized with the Python function. This is done so that all the computations and internal checks are performed inside **PINA**.\n", "\n", "Once we have defined the functions, we need to tell the neural network where they are to be applied. To do so, we use the `Condition` class, to which we pass the domain and the equation we want minimized on the points of that domain (other possibilities are allowed, see the documentation for reference).\n", "\n", "Finally, it's possible to define a `truth_solution` function, which can be useful if we want to plot the results and see how the predicted solution compares to the expected (true) solution. Notice that `truth_solution` is a method of the problem class (here `SimpleODE`), and it is not mandatory for problem definition.\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "78b30f95", "metadata": {}, "source": [ "## Generate data\n", "\n", "Training data can come in the form of direct numerical simulation results or points in the domains. 
If we perform unsupervised learning, we just need the collocation points for training, i.e. points where we want to evaluate the neural network. Sampling points in **PINA** is very easy; here we show three examples using the `.discretise_domain` method of the `AbstractProblem` class." ] }, { "cell_type": "code", "execution_count": 3, "id": "09ce5c3a", "metadata": {}, "outputs": [], "source": [ "# sampling 20 points in [0, 1] on a grid, in all domains\n", "problem.discretise_domain(n=20, mode='grid', domains='all')\n", "\n", "# sampling 20 points in (0, 1) through latin hypercube sampling in D, and 1 point in x0\n", "problem.discretise_domain(n=20, mode='latin', domains=['D'])\n", "problem.discretise_domain(n=1, mode='random', domains=['x0'])\n", "\n", "# sampling 20 points in (0, 1) randomly\n", "problem.discretise_domain(n=20, mode='random')" ] }, { "cell_type": "markdown", "id": "8fbb679f", "metadata": {}, "source": [ "We are going to use Latin hypercube sampling. We need to sample in the domains of all the conditions; in our case, these are `D` and `x0`." 
] }, { "cell_type": "code", "execution_count": 4, "id": "329962b6", "metadata": {}, "outputs": [], "source": [ "# sampling for training\n", "problem.discretise_domain(1, 'random', domains=['x0'])\n", "problem.discretise_domain(20, 'lh', domains=['D'])" ] }, { "cell_type": "markdown", "id": "ca2ac5c2", "metadata": {}, "source": [ "The sampled points are stored in a Python `dict`, which the cell below reads from the `discretised_domains` attribute of the problem:" ] }, { "cell_type": "code", "execution_count": 5, "id": "d6ed9aaf", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Input points: {'x0': LabelTensor([[0.]]), 'D': LabelTensor([[0.4519],\n", " [0.4306],\n", " [0.8085],\n", " [0.6035],\n", " [0.8842],\n", " [0.7970],\n", " [0.3849],\n", " [0.1173],\n", " [0.7432],\n", " [0.0200],\n", " [0.5698],\n", " [0.9792],\n", " [0.5295],\n", " [0.3197],\n", " [0.0558],\n", " [0.2836],\n", " [0.1626],\n", " [0.2333],\n", " [0.6633],\n", " [0.9157]])}\n", "Input points labels: ['x']\n" ] } ], "source": [ "print('Input points:', problem.discretised_domains)\n", "print('Input points labels:', problem.discretised_domains['D'].labels)" ] }, { "cell_type": "markdown", "id": "669e8534", "metadata": {}, "source": [ "To visualize the sampled points we can use `matplotlib.pyplot`:" ] }, { "cell_type": "code", "execution_count": null, "id": "3802e22a", "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAjcAAAGdCAYAAADuR1K7AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAIfBJREFUeJzt3XtwVPX9//HXJiEbVJLILSEapFAVFJQxmBDU4VvJNBZHZcQRkSLSVGoF6o8gyk3SajXUKygoo62ljlAoFqlgJhaDd6JggMq9WhAQ3ABFNsglCcnn9wfD2pUA2ZDdZN88HzM7DCefc87nfLKyzznZjR7nnBMAAIARMU09AQAAgMZE3AAAAFOIGwAAYApxAwAATCFuAACAKcQNAAAwhbgBAACmEDcAAMCUuKaeQFOora3Vrl271KpVK3k8nqaeDgAAqAfnnA4cOKC0tDTFxJz8/sxZGTe7du1Senp6U08DAAA0wI4dO3ThhRee9OtnZdy0atVK0rHFSUxMbOLZAACA+qioqFB6enrgdfxkzsq4Of6jqMTEROIGAIAoc7q3lPCGYgAAYApxAwAATCFuAACAKcQNAAAwhbgBAACmEDcAAMAU4gYAAJhC3AAAAFOIGwAAYApxAwAATCFuAACAKcQNAAAwhbgBAACmEDcAAMAU4gYAAJhC3AAAAFOIGwAAYApxAwAATCFuAACAKcQNAAAwhbgBAACmEDcAAMAU4gYAAJhC3AAAAFOIGwAAYApxAwAATCFuAACAKcQNAAAwhbgBAACmEDcAAMAU4gYAAJhC3AAAAFOIGwAAYApxAwAATCFuAACAKcQNAAAwhbgBAACmEDcAAMAU4gYAAJhC3AAAAFOIGwAAYApxAwAATCFuAACAKRGJm5kzZ6pTp05KSEhQVlaWVqxYccrxCxYsUNeuXZWQkKAePXqoqKjopGPvvfdeeTweTZs2rZFnDQAAolHY42b+/PnKz89XQUGBVq1apSuvvFK5ubnavXt3neOXL1+uwYMHKy8vT6tXr9aAAQM0YMAArVu37oSxb7zxhj755BOlpaWF+zIAAECUCHvcPPPMM7rnnns0fPhwXXbZZZo1a5bOOeccvfLKK3WOnz59um644QaNGzdO3bp106OPPqqrrrpKM2bMCBq3c+dOjR49WnPmzFGLFi3CfRkAACBKhDVuqqqqVFZWppycnO9PGBOjnJwclZaW1rlPaWlp0HhJys3NDRpfW1uroUOHaty4cbr88stPO4/KykpVVFQEPQAAgE1hjZu9e/eqpqZGKSkpQdtTUlLk8/nq3Mfn8512/B/+8AfFxcXpN7/5Tb3mUVhYqKSkpMAjPT09xCsBAADRIuo+LVVWVqbp06dr9uzZ8ng89dpnwoQJ8vv9gceOHTvCPEsAANBUwho3bdu2VWxsrMrLy4O2l5eXKzU1tc59UlNTTzn+ww8/1O7du9WxY0fFxcUpLi5O27Zt09ixY9WpU6c6j+n1epWYmBj0AAAANoU1buLj45WRkaGSkpLAttraWpWUlCg7O7vOfbKzs4PGS9LSpUsD44cOHarPP/9ca9asCTzS0tI0btw4vf322+G7GAAAEBXiwn2C/Px8DRs2TL169VJmZqamTZumgwcPavjw4ZKku+66SxdccIEKCwslSffff7/69u2rp59+WjfeeKPmzZunzz77TC+99JIkqU2bNmrTpk3QOVq0aKHU1FRdeuml4b4cAADQzIU9bgYNGqQ9e/ZoypQp8vl86tmzp4qLiwNvGt6+fbtiYr6/gdSnTx/NnTtXkydP1sSJE3XxxRdr0aJF6t69e7inCgAADPA451xTTyLSKioqlJSUJL/fz/tvAACIEvV9/Y66T0sBAACcCnEDAABMIW4AAIApxA0AADCFuAEAAKYQNwAAwBTiBgAAmELcAAAAU4gbAABgCnEDAABMIW4AAIApxA0AADCFuAEAAKYQNwAAwBTiBgAAmELcAAAAU4gbAABgCnEDAABMIW4
AAIApxA0AADCFuAEAAKYQNwAAwBTiBgAAmELcAAAAU4gbAABgCnEDAABMIW4AAIApxA0AADCFuAEAAKYQNwAAwBTiBgAAmELcAAAAU4gbAABgCnEDAABMIW4AAIApxA0AADCFuAEAAKYQNwAAwBTiBgAAmELcAAAAU4gbAABgCnEDAABMIW4AAIApxA0AADCFuAEAAKYQNwAAwBTiBgAAmELcAAAAU4gbAABgCnEDAABMIW4AAIApxA0AADCFuAEAAKYQNwAAwBTiBgAAmELcAAAAUyISNzNnzlSnTp2UkJCgrKwsrVix4pTjFyxYoK5duyohIUE9evRQUVFR4GvV1dV66KGH1KNHD5177rlKS0vTXXfdpV27doX7MgAAQBQIe9zMnz9f+fn5Kigo0KpVq3TllVcqNzdXu3fvrnP88uXLNXjwYOXl5Wn16tUaMGCABgwYoHXr1kmSDh06pFWrVunhhx/WqlWrtHDhQm3evFk333xzuC8FAABEAY9zzoXzBFlZWbr66qs1Y8YMSVJtba3S09M1evRojR8//oTxgwYN0sGDB7VkyZLAtt69e6tnz56aNWtWnedYuXKlMjMztW3bNnXs2PG0c6qoqFBSUpL8fr8SExMbeGUAACCS6vv6HdY7N1VVVSorK1NOTs73J4yJUU5OjkpLS+vcp7S0NGi8JOXm5p50vCT5/X55PB4lJyfX+fXKykpVVFQEPQAAgE1hjZu9e/eqpqZGKSkpQdtTUlLk8/nq3Mfn84U0/siRI3rooYc0ePDgk1ZcYWGhkpKSAo/09PQGXA0AAIgGUf1pqerqat1+++1yzunFF1886bgJEybI7/cHHjt27IjgLAEAQCTFhfPgbdu2VWxsrMrLy4O2l5eXKzU1tc59UlNT6zX+eNhs27ZNy5YtO+XP3rxer7xebwOvAgAARJOw3rmJj49XRkaGSkpKAttqa2tVUlKi7OzsOvfJzs4OGi9JS5cuDRp/PGy++OILvfPOO2rTpk14LgAAAESdsN65kaT8/HwNGzZMvXr1UmZmpqZNm6aDBw9q+PDhkqS77rpLF1xwgQoLCyVJ999/v/r27aunn35aN954o+bNm6fPPvtML730kqRjYXPbbbdp1apVWrJkiWpqagLvx2ndurXi4+PDfUkAAKAZC3vcDBo0SHv27NGUKVPk8/nUs2dPFRcXB940vH37dsXEfH8DqU+fPpo7d64mT56siRMn6uKLL9aiRYvUvXt3SdLOnTv15ptvSpJ69uwZdK53331X//d//xfuSwIAAM1Y2H/PTXPE77kBACD6NIvfcwMAABBpxA0AADCFuAEAAKYQNwAAwBTiBgAAmELcAAAAU4gbAABgCnEDAABMIW4AAIApxA0AADCFuAEAAKYQNwAAwBTiBgAAmELcAAAAU4gbAABgCnEDAABMIW4AAIApxA0AADCFuAEAAKYQNwAAwBTiBgAAmELcAAAAU4gbAABgCnEDAABMIW4AAIApxA0AADCFuAEAAKYQNwAAwBTiBgAAmELcAAAAU4gbAABgCnEDAABMIW4AAIApxA0AADCFuAEAAKYQNwAAwBTiBgAAmELcAAAAU4gbAABgCnEDAABMIW4AAIApxA0AADCFuAEAAKYQNwAAwBTiBgAAmELcAAAAU4gbAABgCnEDAABMIW4AAIApxA0AADCFuAEAAKYQNwAAwBTiBgAAmELcAAAAU4gbAABgCnEDAABMiUjczJw5U506dVJCQoKysrK0YsWKU45fsGCBunbtqoSEBPXo0UNFRUVBX3fOacqUKerQoYNatmypnJwcffHFF+G8BAAAECXCHjfz589Xfn6+CgoKtGrVKl155ZXKzc3V7t276xy/fPlyDR48WHl5eVq9erUGDBigAQMGaN26dYExTzzxhJ577jnNmjVLn376qc4991zl5ubqyJEj4b6c0/rGf1jL/7NX3/gPn3yQf6e09YNjfzbEme7fHM7ZFNcQyTk0xrGbao3Ccd7
6HrM5PC/CMZczOVak16Q5fQ+Oi8Sc6jpHc1qLhs7FwutFA3iccy6cJ8jKytLVV1+tGTNmSJJqa2uVnp6u0aNHa/z48SeMHzRokA4ePKglS5YEtvXu3Vs9e/bUrFmz5JxTWlqaxo4dqwceeECS5Pf7lZKSotmzZ+uOO+447ZwqKiqUlJQkv9+vxMTERrpSaf7K7ZqwcK1qnRTjkQpv7aFBV3cMHrTqVWnx/ZKrlTwx0k3Tpavuqv9JznT/hmjsczbFNURyDo1x7KZao3Cct77HbA7Pi3DM5UyOFek1aU7fg0jOqa5zSM1nLRq6BhZeL36gvq/fYb1zU1VVpbKyMuXk5Hx/wpgY5eTkqLS0tM59SktLg8ZLUm5ubmD81q1b5fP5gsYkJSUpKyvrpMesrKxURUVF0KOxfeM/HAgbSap10sSF64Lv4Ph3fv9Nl479ufj/1b9uz3T/hmjsczbFNURyDo1x7KZao3Cct77HbA7Pi3DM5UyOFek1aU7fg0jOqc5z3C+9+ZvmsRYNXQMLrxdnIKxxs3fvXtXU1CglJSVoe0pKinw+X537+Hy+U44//mcoxywsLFRSUlLgkZ6e3qDrOZWtew8Gwua4Guf01d5D32/Y95/vv+nHuRpp35b6neRM92+Ixj5nU1xDJOfQGMduqjUKx3nre8zm8LwIx1zO5FiRXpPm9D04LhJzqvMctZJ+8A96tD0fLbxenIGz4tNSEyZMkN/vDzx27NjR6Of4UdtzFeMJ3hbr8ahT23O+39C6y7HbdP/LEyu17ly/k5zp/g3R2OdsimuI5Bwa49hNtUbhOG99j9kcnhfhmMuZHCvSa9KcvgfHRWJOdZ4jRtIP/kGPtuejhdeLMxDWuGnbtq1iY2NVXl4etL28vFypqal17pOamnrK8cf/DOWYXq9XiYmJQY/G1iGppQpv7aFYz7H/IGI9Hj1+a3d1SGr5/aCkC479/NETe+zvnljppmnHttfHme7fEI19zqa4hkjOoTGO3VRrFI7z1veYzeF5EY65nMmxIr0mzel7EMk51XmO6dLNzzWPtWjoGlh4vTgDEXlDcWZmpp5//nlJx95Q3LFjR40aNeqkbyg+dOiQFi9eHNjWp08fXXHFFUFvKH7ggQc0duxYScfeYNS+ffsmf0OxdOy9N1/tPaRObc8JDpv/5d957DZd684N+6af6f4N0djnbIpriOQcGuPYTbVG4ThvfY/ZHJ4X4ZjLmRwr0mvSnL4Hx0ViTnWdozmtRUPnYuH14n/U+/Xbhdm8efOc1+t1s2fPdhs2bHAjRoxwycnJzufzOeecGzp0qBs/fnxg/Mcff+zi4uLcU0895TZu3OgKCgpcixYt3Nq1awNjpk6d6pKTk90//vEP9/nnn7tbbrnF/ehHP3KHDx+u15z8fr+T5Px+f+NeLAAACJv6vn7HNWpS1WHQoEHas2ePpkyZIp/Pp549e6q4uDjwhuDt27crJub7n4716dNHc+fO1eTJkzVx4kRdfPHFWrRokbp37x4Y8+CDD+rgwYMaMWKE9u/fr2uvvVbFxcVKSEgI9+UAAIBmLuw/lmqOwvljKQAAEB7N4vfcAAAARBpxAwAATCFuAACAKcQNAAAwhbgBAACmEDcAAMAU4gYAAJhC3AAAAFOIGwAAYApxAwAATCFuAACAKcQNAAAwhbgBAACmEDcAAMAU4gYAAJhC3AAAAFOIGwAAYApxAwAATCFuAACAKcQNAAAwhbgBAACmEDcAAMAU4gYAAJhC3AAAAFOIGwAAYApxAwAATCFuAACAKcQNAAAwhbgBAACmEDcAAMAU4gYAAJhC3AAAAFOIGwAAYApxAwAATCFuAACAKcQNAAAwhbgBAACmEDcAAMAU4gYAAJhC3AAAAFOIGwAAYApxAwAATCFuAACAKcQNAAAwhbgBAACmEDcAAMAU4gYAAJhC3AAAAFOIGwAAYApxAwAATCFuAACAKcQNAAAwhbgBAACmEDcAAMAU4gYAAJhC3AA
AAFPCFjf79u3TkCFDlJiYqOTkZOXl5em777475T5HjhzRyJEj1aZNG5133nkaOHCgysvLA1//17/+pcGDBys9PV0tW7ZUt27dNH369HBdAgAAiEJhi5shQ4Zo/fr1Wrp0qZYsWaIPPvhAI0aMOOU+Y8aM0eLFi7VgwQK9//772rVrl2699dbA18vKytS+fXu99tprWr9+vSZNmqQJEyZoxowZ4boMAAAQZTzOOdfYB924caMuu+wyrVy5Ur169ZIkFRcXq3///vr666+VlpZ2wj5+v1/t2rXT3Llzddttt0mSNm3apG7duqm0tFS9e/eu81wjR47Uxo0btWzZsnrPr6KiQklJSfL7/UpMTGzAFQIAgEir7+t3WO7clJaWKjk5ORA2kpSTk6OYmBh9+umnde5TVlam6upq5eTkBLZ17dpVHTt2VGlp6UnP5ff71bp168abPAAAiGpx4Tioz+dT+/btg08UF6fWrVvL5/OddJ/4+HglJycHbU9JSTnpPsuXL9f8+fP11ltvnXI+lZWVqqysDPy9oqKiHlcBAACiUUh3bsaPHy+Px3PKx6ZNm8I11yDr1q3TLbfcooKCAv30pz895djCwkIlJSUFHunp6RGZIwAAiLyQ7tyMHTtWd9999ynHdO7cWampqdq9e3fQ9qNHj2rfvn1KTU2tc7/U1FRVVVVp//79QXdvysvLT9hnw4YN6tevn0aMGKHJkyefdt4TJkxQfn5+4O8VFRUEDgAARoUUN+3atVO7du1OOy47O1v79+9XWVmZMjIyJEnLli1TbW2tsrKy6twnIyNDLVq0UElJiQYOHChJ2rx5s7Zv367s7OzAuPXr1+v666/XsGHD9Nhjj9Vr3l6vV16vt15jAQBAdAvLp6Uk6Wc/+5nKy8s1a9YsVVdXa/jw4erVq5fmzp0rSdq5c6f69eunV199VZmZmZKkX//61yoqKtLs2bOVmJio0aNHSzr23hrp2I+irr/+euXm5urJJ58MnCs2NrZe0XUcn5YCACD61Pf1OyxvKJakOXPmaNSoUerXr59iYmI0cOBAPffcc4GvV1dXa/PmzTp06FBg27PPPhsYW1lZqdzcXL3wwguBr7/++uvas2ePXnvtNb322muB7RdddJG++uqrcF0KAACIImG7c9OccecGAIDo06S/5wYAAKCpEDcAAMAU4gYAAJhC3AAAAFOIGwAAYApxAwAATCFuAACAKcQNAAAwhbgBAACmEDcAAMAU4gYAAJhC3AAAAFOIGwAAYApxAwAATCFuAACAKcQNAAAwhbgBAACmEDcAAMAU4gYAAJhC3AAAAFOIGwAAYApxAwAATCFuAACAKcQNAAAwhbgBAACmEDcAAMAU4gYAAJhC3AAAAFOIGwAAYApxAwAATCFuAACAKcQNAAAwhbgBAACmEDcAAMAU4gYAAJhC3AAAAFOIGwAAYApxAwAATCFuAACAKcQNAAAwhbgBAACmEDcAAMAU4gYAAJhC3AAAAFOIGwAAYApxAwAATCFuAACAKcQNAAAwhbgBAACmEDcAAMAU4gYAAJhC3AAAAFOIGwAAYApxAwAATCFuAACAKcQNAAAwhbgBAACmhC1u9u3bpyFDhigxMVHJycnKy8vTd999d8p9jhw5opEjR6pNmzY677zzNHDgQJWXl9c59r///a8uvPBCeTwe7d+/PwxXAAAAolHY4mbIkCFav369li5dqiVLluiDDz7QiBEjTrnPmDFjtHjxYi1YsEDvv/++du3apVtvvbXOsXl5ebriiivCMXUAABDFPM4519gH3bhxoy677DKtXLlSvXr1kiQVFxerf//++vrrr5WWlnbCPn6/X+3atdPcuXN12223SZI2bdqkbt26qbS0VL179w6MffHFFzV//nxNmTJF/fr107fffqvk5OR6z6+iokJJSUny+/1KTEw8s4sFAAARUd/X77DcuSktLVVycnIgbCQpJydHMTEx+vTTT+vcp6ysTNXV1crJyQls69q
1qzp27KjS0tLAtg0bNuiRRx7Rq6++qpiY+k2/srJSFRUVQQ8AAGBTWOLG5/Opffv2Qdvi4uLUunVr+Xy+k+4THx9/wh2YlJSUwD6VlZUaPHiwnnzySXXs2LHe8yksLFRSUlLgkZ6eHtoFAQCAqBFS3IwfP14ej+eUj02bNoVrrpowYYK6deumn//85yHv5/f7A48dO3aEaYYAAKCpxYUyeOzYsbr77rtPOaZz585KTU3V7t27g7YfPXpU+/btU2pqap37paamqqqqSvv37w+6e1NeXh7YZ9myZVq7dq1ef/11SdLxtwu1bdtWkyZN0u9+97s6j+31euX1eutziQAAIMqFFDft2rVTu3btTjsuOztb+/fvV1lZmTIyMiQdC5Pa2lplZWXVuU9GRoZatGihkpISDRw4UJK0efNmbd++XdnZ2ZKkv//97zp8+HBgn5UrV+oXv/iFPvzwQ3Xp0iWUSwEAAEaFFDf11a1bN91www265557NGvWLFVXV2vUqFG64447Ap+U2rlzp/r166dXX31VmZmZSkpKUl5envLz89W6dWslJiZq9OjRys7ODnxS6ocBs3fv3sD5Qvm0FAAAsCsscSNJc+bM0ahRo9SvXz/FxMRo4MCBeu655wJfr66u1ubNm3Xo0KHAtmeffTYwtrKyUrm5uXrhhRfCNUUAAGBQWH7PTXPH77kBACD6NOnvuQEAAGgqxA0AADCFuAEAAKYQNwAAwBTiBgAAmELcAAAAU4gbAABgCnEDAABMIW4AAIApxA0AADCFuAEAAKYQNwAAwBTiBgAAmELcAAAAU4gbAABgCnEDAABMIW4AAIApxA0AADCFuAEAAKYQNwAAwBTiBgAAmELcAAAAU4gbAABgCnEDAABMIW4AAIApxA0AADCFuAEAAKYQNwAAwBTiBgAAmELcAAAAU4gbAABgCnEDAABMIW4AAIApxA0AADCFuAEAAKYQNwAAwBTiBgAAmELcAAAAU4gbAABgCnEDAABMIW4AAIApxA0AADAlrqkn0BScc5KkioqKJp4JAACor+Ov28dfx0/mrIybAwcOSJLS09ObeCYAACBUBw4cUFJS0km/7nGnyx+DamtrtWvXLrVq1Uoej6dRj11RUaH09HTt2LFDiYmJjXpsnIj1jizWO7JY78hivSOrIevtnNOBAweUlpammJiTv7PmrLxzExMTowsvvDCs50hMTOQ/jghivSOL9Y4s1juyWO/ICnW9T3XH5jjeUAwAAEwhbgAAgCnETSPzer0qKCiQ1+tt6qmcFVjvyGK9I4v1jizWO7LCud5n5RuKAQCAXdy5AQAAphA3AADAFOIGAACYQtwAAABTiJsQzZw5U506dVJCQoKysrK0YsWKU45fsGCBunbtqoSEBPXo0UNFRUURmqkdoaz5yy+/rOuuu07nn3++zj//fOXk5Jz2e4RgoT7Hj5s3b548Ho8GDBgQ3gkaE+p679+/XyNHjlSHDh3k9Xp1ySWX8O9KCEJd72nTpunSSy9Vy5YtlZ6erjFjxujIkSMRmm10++CDD3TTTTcpLS1NHo9HixYtOu0+7733nq666ip5vV79+Mc/1uzZsxt2cod6mzdvnouPj3evvPKKW79+vbvnnntccnKyKy8vr3P8xx9/7GJjY90TTzzhNmzY4CZPnuxatGjh1q5dG+GZR69Q1/zOO+90M2fOdKtXr3YbN250d999t0tKSnJff/11hGcenUJd7+O2bt3qLrjgAnfddde5W265JTKTNSDU9a6srHS9evVy/fv3dx999JHbunWre++999yaNWsiPPPoFOp6z5kzx3m9Xjdnzhy3detW9/bbb7sOHTq4MWPGRHjm0amoqMhNmjTJLVy40Elyb7zxxinHb9myxZ1zzjkuPz/fbdiwwT3//PMuNjbWFRcXh3xu4iYEmZmZbuTIkYG/19TUuLS0NFdYWFjn+Ntvv93deOONQduysrLcr371q7DO05JQ1/yHjh496lq1auX+8pe
/hGuKpjRkvY8ePer69Onj/vjHP7phw4YRNyEIdb1ffPFF17lzZ1dVVRWpKZoS6nqPHDnSXX/99UHb8vPz3TXXXBPWeVpUn7h58MEH3eWXXx60bdCgQS43Nzfk8/FjqXqqqqpSWVmZcnJyAttiYmKUk5Oj0tLSOvcpLS0NGi9Jubm5Jx2PYA1Z8x86dOiQqqur1bp163BN04yGrvcjjzyi9u3bKy8vLxLTNKMh6/3mm28qOztbI0eOVEpKirp3767HH39cNTU1kZp21GrIevfp00dlZWWBH11t2bJFRUVF6t+/f0TmfLZpzNfMs/J/nNkQe/fuVU1NjVJSUoK2p6SkaNOmTXXu4/P56hzv8/nCNk9LGrLmP/TQQw8pLS3thP9gcKKGrPdHH32kP/3pT1qzZk0EZmhLQ9Z7y5YtWrZsmYYMGaKioiJ9+eWXuu+++1RdXa2CgoJITDtqNWS977zzTu3du1fXXnutnHM6evSo7r33Xk2cODESUz7rnOw1s6KiQocPH1bLli3rfSzu3MCsqVOnat68eXrjjTeUkJDQ1NMx58CBAxo6dKhefvlltW3btqmnc1aora1V+/bt9dJLLykjI0ODBg3SpEmTNGvWrKaemknvvfeeHn/8cb3wwgtatWqVFi5cqLfeekuPPvpoU08Np8Gdm3pq27atYmNjVV5eHrS9vLxcqampde6Tmpoa0ngEa8iaH/fUU09p6tSpeuedd3TFFVeEc5pmhLre//nPf/TVV1/ppptuCmyrra2VJMXFxWnz5s3q0qVLeCcdxRry/O7QoYNatGih2NjYwLZu3brJ5/OpqqpK8fHxYZ1zNGvIej/88MMaOnSofvnLX0qSevTooYMHD2rEiBGaNGmSYmK4P9CYTvaamZiYGNJdG4k7N/UWHx+vjIwMlZSUBLbV1taqpKRE2dnZde6TnZ0dNF6Sli5detLxCNaQNZekJ554Qo8++qiKi4vVq1evSEzVhFDXu2vXrlq7dq3WrFkTeNx88836yU9+ojVr1ig9PT2S0486DXl+X3PNNfryyy8DESlJ//73v9WhQwfC5jQast6HDh06IWCOh6Xjf8vY6Br1NTPktyCfxebNm+e8Xq+bPXu227BhgxsxYoRLTk52Pp/POefc0KFD3fjx4wPjP/74YxcXF+eeeuopt3HjRldQUMBHwUMU6ppPnTrVxcfHu9dff9198803gceBAwea6hKiSqjr/UN8Wio0oa739u3bXatWrdyoUaPc5s2b3ZIlS1z79u3d73//+6a6hKgS6noXFBS4Vq1aub/+9a9uy5Yt7p///Kfr0qWLu/3225vqEqLKgQMH3OrVq93q1audJPfMM8+41atXu23btjnnnBs/frwbOnRoYPzxj4KPGzfObdy40c2cOZOPgkfK888/7zp27Oji4+NdZmam++STTwJf69u3rxs2bFjQ+L/97W/ukksucfHx8e7yyy93b731VoRnHP1CWfOLLrrISTrhUVBQEPmJR6lQn+P/i7gJXajrvXz5cpeVleW8Xq/r3Lmze+yxx9zRo0cjPOvoFcp6V1dXu9/+9reuS5cuLiEhwaWnp7v77rvPffvtt5GfeBR699136/z3+PgaDxs2zPXt2/eEfXr27Oni4+Nd586d3Z///OcGndvjHPfWAACAHbznBgAAmELcAAAAU4gbAABgCnEDAABMIW4AAIApxA0AADCFuAEAAKYQNwAAwBTiBgAAmELcAAAAU4gbAABgCnEDAABM+f/F1ffAV2cNNQAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "variables=problem.spatial_variables\n", "fig = plt.figure()\n", "proj = \"3d\" if len(variables) == 3 else None\n", "ax = fig.add_subplot(projection=proj)\n", "for location in problem.input_pts:\n", " coords = problem.input_pts[location].extract(variables).T.detach()\n", " ax.plot(coords.flatten(),torch.zeros(coords.flatten().shape),\".\",label=location)\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "22e502dd", "metadata": {}, "source": [ "## Perform a small training" ] }, { "attachments": {}, "cell_type": "markdown", "id": "075f43f5", "metadata": {}, "source": [ "Once we have defined the problem and generated the data we can start the modelling. Here we will choose a `FeedForward` neural network available in `pina.model`, and we will train using the `PINN` solver from `pina.solver`. We highlight that this training is fairly simple, for more advanced stuff consider the tutorials in the ***Physics Informed Neural Networks*** section of ***Tutorials***. For training we use the `Trainer` class from `pina.trainer`. Here we show a very short training and some method for plotting the results. Notice that by default all relevant metrics (e.g. MSE error during training) are going to be tracked using a `lightning` logger, by default `CSVLogger`. If you want to track the metric by yourself without a logger, use `pina.callback.MetricTracker`." ] }, { "cell_type": "code", "execution_count": 7, "id": "3bb4dc9b", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: False, used: False\n", "TPU available: False, using: 0 TPU cores\n", "HPU available: False, using: 0 HPUs\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d2c3b03173424844beead0135687f8a1", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Sanity Checking: | | 0/? 
[00:00" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "pts = pinn.problem.spatial_domain.sample(256, 'grid', variables='x')\n", "predicted_output = pinn.forward(pts).extract('u').as_subclass(torch.Tensor).cpu().detach()\n", "true_output = pinn.problem.truth_solution(pts).cpu().detach()\n", "pts = pts.cpu()\n", "fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))\n", "ax.plot(pts.extract(['x']), predicted_output, label='Neural Network solution')\n", "ax.plot(pts.extract(['x']), true_output, label='True solution')\n", "plt.legend()" ] }, { "cell_type": "markdown", "id": "bf47b98a", "metadata": {}, "source": [ "The predicted solution overlaps the true one, and the two are barely distinguishable. We can also easily plot the loss:" ] }, { "cell_type": "code", "execution_count": 10, "id": "bf6211e6", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# find the MetricTracker callback among the trainer callbacks\n", "list_ = [\n", "    idx for idx, s in enumerate(trainer.callbacks)\n", "    if isinstance(s, MetricTracker)\n", "]\n", "trainer_metrics = trainer.callbacks[list_[0]].metrics\n", "\n", "loss = trainer_metrics['val_loss']\n", "epochs = range(len(loss))\n", "\n", "# plotting\n", "plt.plot(epochs, loss.cpu())\n", "plt.xlabel('epoch')\n", "plt.ylabel('loss')\n", "plt.yscale('log')" ] }, { "cell_type": "markdown", "id": "58172899", "metadata": {}, "source": [ "As we can see, the loss has not reached a minimum, suggesting that we could train for longer." ] }, { "cell_type": "markdown", "id": "33e672da", "metadata": {}, "source": [ "## What's next?\n", "\n", "Congratulations on completing the introductory tutorial of **PINA**! There are several directions you can go now:\n", "\n", "1. Train the network for longer or with different layer sizes and assess the final accuracy\n", "\n", "2. Train the network using other types of models (see `pina.model`)\n", "\n", "3. GPU training and speed benchmarking\n", "\n", "4. Many more..." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.7" } }, "nbformat": 4, "nbformat_minor": 5 }