{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Tutorial: Introduction to `Trainer` class\n", "[](https://colab.research.google.com/github/mathLab/PINA/blob/master/tutorials/tutorial11/tutorial.ipynb)\n", "\n", "In this tutorial, we will delve deeper into the functionality of the `Trainer` class, which serves as the cornerstone for training **PINA** [Solvers](https://mathlab.github.io/PINA/_rst/_code.html#solvers).\n", "\n", "The `Trainer` class offers a wide range of features aimed at improving model accuracy, reducing training time and memory usage, facilitating logging and visualization, and more, thanks to the amazing work done by the PyTorch Lightning team!\n", "\n", "Our running example is a simple regression problem: we want to approximate the function\n", "$$y = x^3$$\n", "with a neural network model $\\mathcal{M}_{\\theta}$, given only a set of $20$ noisy observations $\\{x_i, y_i\\}_{i=1}^{20}$, with $x_i \\sim\\mathcal{U}[-3, 3]\\;\\;\\forall i\\in\\{1,\\dots,20\\}$.\n", "\n", "Let's start by importing the modules we need!" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "try:\n", "    import google.colab\n", "\n", "    IN_COLAB = True\n", "except ImportError:\n", "    IN_COLAB = False\n", "if IN_COLAB:\n", "    !pip install \"pina-mathlab[tutorial]\"\n", "\n", "import torch\n", "import warnings\n", "\n", "from pina import Trainer\n", "from pina.solver import SupervisedSolver\n", "from pina.model import FeedForward\n", "from pina.problem.zoo import SupervisedProblem\n", "\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now define the problem and the solver." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# define the problem: 20 noisy samples of y = x^3\n", "x_train = torch.empty((20, 1)).uniform_(-3, 3)\n", "y_train = x_train.pow(3) + 3 * torch.randn_like(x_train)\n", "\n", "problem = SupervisedProblem(x_train, y_train)\n", "\n", "# build the model\n", "model = FeedForward(\n", "    layers=[10, 10],\n", "    func=torch.nn.Tanh,\n", "    output_dimensions=1,\n", "    input_dimensions=1,\n", ")\n", "\n", "# create the SupervisedSolver object\n", "solver = SupervisedSolver(problem, model, use_lt=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So far we have followed exactly the same steps as in the previous tutorials. The `Trainer` object\n", "can be initialized by simply passing the `SupervisedSolver` object:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True (mps), used: True\n", "TPU available: False, using: 0 TPU cores\n", "HPU available: False, using: 0 HPUs\n" ] } ], "source": [ "trainer = Trainer(solver=solver)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Trainer Accelerator\n", "\n", "When the `Trainer` is created, **by default** it selects the most powerful `accelerator` available on your system, ranked as follows:\n", "1. [TPU](https://cloud.google.com/tpu/docs/intro-to-tpu)\n", "2. [IPU](https://www.graphcore.ai/products/ipu)\n", "3. [HPU](https://habana.ai/)\n", "4. [GPU](https://www.intel.com/content/www/us/en/products/docs/processors/what-is-a-gpu.html#:~:text=What%20does%20GPU%20stand%20for,video%20editing%2C%20and%20gaming%20applications) or [MPS](https://developer.apple.com/metal/pytorch/)\n", "5. 
CPU\n", "\n", "To set the `accelerator` manually, pass:\n", "\n", "* `accelerator = {'gpu', 'cpu', 'hpu', 'mps', 'tpu', 'ipu'}` to select a specific one" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True (mps), used: False\n", "TPU available: False, using: 0 TPU cores\n", "HPU available: False, using: 0 HPUs\n" ] } ], "source": [ "trainer = Trainer(solver=solver, accelerator=\"cpu\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, even though a GPU is available on the system, it is not used because we set `accelerator='cpu'`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Trainer Logging\n", "\n", "In **PINA** you can log metrics in different ways. The simplest approach is to use the `MetricTracker` class from `pina.callbacks`, as seen in the [*Introduction to Physics Informed Neural Networks training*](https://github.com/mathLab/PINA/blob/master/tutorials/tutorial1/tutorial.ipynb) tutorial.\n", "\n", "However, especially when we need to train several times to average the loss across runs, the loggers in `lightning.pytorch.loggers` can be useful. Here we will use `TensorBoardLogger` (see the Lightning documentation on [logging](https://lightning.ai/docs/pytorch/stable/extensions/logging.html) for more details), but you can choose whichever logger you prefer, or write your own.\n", "\n", "We will now import `TensorBoardLogger`, run the training three times, and then visualize the results. Note that we set `enable_model_summary=False` to suppress the model summary (e.g. the number of parameters); set it to `True` if you need it." 
] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True (mps), used: False\n", "TPU available: False, using: 0 TPU cores\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "HPU available: False, using: 0 HPUs\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "775a2d088e304b2589631b176c9e99e2", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Training: | | 0/? [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "`Trainer.fit` stopped: `max_epochs=100` reached.\n", "GPU available: True (mps), used: False\n", "TPU available: False, using: 0 TPU cores\n", "HPU available: False, using: 0 HPUs\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d858dc0a31214f5f86aae78823525b56", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Training: | | 0/? [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "`Trainer.fit` stopped: `max_epochs=100` reached.\n", "GPU available: True (mps), used: False\n", "TPU available: False, using: 0 TPU cores\n", "HPU available: False, using: 0 HPUs\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "739bf2009f7a48a1b59b7df695276672", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Training: | | 0/? 
[00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "`Trainer.fit` stopped: `max_epochs=100` reached.\n" ] } ], "source": [ "from lightning.pytorch.loggers import TensorBoardLogger\n", "\n", "# three training runs: by default Lightning trains for 1000 epochs, here we cap it at 100\n", "# the model is re-initialized at every run, otherwise the same parameters would keep being optimized\n", "for _ in range(3):\n", "    model = FeedForward(\n", "        layers=[10, 10],\n", "        func=torch.nn.Tanh,\n", "        output_dimensions=1,\n", "        input_dimensions=1,\n", "    )\n", "    solver = SupervisedSolver(problem, model, use_lt=False)\n", "    trainer = Trainer(\n", "        solver=solver,\n", "        accelerator=\"cpu\",\n", "        logger=TensorBoardLogger(save_dir=\"training_log\"),\n", "        enable_model_summary=False,\n", "        train_size=1.0,\n", "        val_size=0.0,\n", "        test_size=0.0,\n", "        max_epochs=100,\n", "    )\n", "    trainer.train()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now visualize the logs by running `tensorboard --logdir=training_log/` in the terminal. If you train for 1000 epochs, you should see a page similar to the one shown below:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n",
"
\n",
"