{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "6f71ca5c", "metadata": {}, "source": [ "# Tutorial: Introduction to Solver classes\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mathLab/PINA/blob/master/tutorials/tutorial18/tutorial.ipynb)\n", "\n", "In this tutorial, we will explore the Solver classes in PINA, that are the core components for optimizing models. Solvers are designed to manage and execute the optimization process, providing the flexibility to work with various types of neural networks and loss functions. We will show how to use this class to select and implement different solvers, such as Supervised Learning, Physics-Informed Neural Networks (PINNs), and Generative Learning solvers. By the end of this tutorial, you'll be equipped to easily choose and customize solvers for your own tasks, streamlining the model training process.\n", "\n", "## Introduction to Solvers\n", "\n", "[`Solvers`](https://mathlab.github.io/PINA/_rst/_code.html#solvers) are versatile objects in PINA designed to manage the training and optimization of machine learning models. They handle key components of the learning process, including:\n", "\n", "- Loss function minimization \n", "- Model optimization (optimizer, schedulers)\n", "- Validation and testing workflows\n", "\n", "PINA solvers are built on top of the [PyTorch Lightning `LightningModule`](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html), which provides a structured and scalable training framework. This allows solvers to leverage advanced features such as distributed training, early stopping, and logging — all with minimal setup.\n", "\n", "## Solvers Hierarchy: Single and MultiSolver\n", "\n", "PINA provides two main abstract interfaces for solvers, depending on whether the training involves a single model or multiple models. These interfaces define the base functionality that all specific solver implementations inherit from.\n", "\n", "### 1. [`SingleSolverInterface`](https://mathlab.github.io/PINA/_rst/solver/solver_interface.html)\n", "\n", "This is the abstract base class for solvers that train **a single model**, such as in standard supervised learning or physics-informed training. All specific solvers (e.g., `SupervisedSolver`, `PINN`) inherit from this interface.\n", "\n", "**Arguments:**\n", "- `problem` – The problem to be solved.\n", "- `model` – The neural network model.\n", "- `optimizer` – Defaults to `torch.optim.Adam` if not provided.\n", "- `scheduler` – Defaults to `torch.optim.lr_scheduler.ConstantLR`.\n", "- `weighting` – Optional loss weighting schema., see [here](https://mathlab.github.io/PINA/_rst/_code.html#losses-and-weightings). We weight already for you!\n", "- `use_lt` – Whether to use LabelTensors as input.\n", "\n", "---\n", "\n", "### 2. [`MultiSolverInterface`](https://mathlab.github.io/PINA/_rst/solver/multi_solver_interface.html)\n", "\n", "This is the abstract base class for solvers involving **multiple models**, such as in GAN architectures or ensemble training strategies. 
All multi-model solvers (e.g., `DeepEnsemblePINN`, `GAROM`) inherit from this interface.\n", "\n", "**Arguments:**\n", "- `problem` – The problem to be solved.\n", "- `models` – The model or models used for training.\n", "- `optimizers` – Defaults to `torch.optim.Adam`.\n", "- `schedulers` – Defaults to `torch.optim.lr_scheduler.ConstantLR`.\n", "- `weightings` – Optional loss weighting schema, see [here](https://mathlab.github.io/PINA/_rst/_code.html#losses-and-weightings). The weighting is already applied for you!\n", "- `use_lt` – Whether to use LabelTensors as input.\n", "\n", "---\n", "\n", "These base classes define the structure and behavior of solvers in PINA, allowing you to create customized training strategies while leveraging PyTorch Lightning's features under the hood.\n", "\n", "These classes define the backbone, i.e., they set the problem, the model(s), the optimizer(s), and the scheduler(s), but they lack a key component: the `optimization_cycle` method.\n", "\n", "## Optimization Cycle\n", "\n", "The `optimization_cycle` method is the core function responsible for computing losses for **all conditions** in a given training batch. Each condition (e.g., initial condition, boundary condition, PDE residual) contributes its own loss, which is tracked and returned in a dictionary. This method should return a dictionary mapping **condition names** to their respective **scalar loss values**.\n", "\n", "For supervised learning tasks, where each condition consists of an input-target pair, the `optimization_cycle` may look like this:\n", "\n", "```python\n", "def optimization_cycle(self, batch):\n", "    \"\"\"\n", "    The optimization cycle for Supervised solvers.\n", "    Computes the loss for each condition in the batch.\n", "    \"\"\"\n", "    condition_loss = {}\n", "    for condition_name, data in batch:\n", "        condition_loss[condition_name] = self.loss_data(\n", "            input=data[\"input\"], target=data[\"target\"]\n", "        )\n", "    return condition_loss\n", "```\n", "\n", "In PINA, a **batch** is structured as a list of tuples, where each tuple corresponds to a specific training condition. Each tuple contains:\n", "\n", "- The **name of the condition**\n", "- A **dictionary of data** associated with that condition\n", "\n", "For example:\n", "\n", "```python\n", "batch = [\n", "    (\"condition1\", {\"input\": ..., \"target\": ...}),\n", "    (\"condition2\", {\"input\": ..., \"equation\": ...}),\n", "    (\"condition3\", {\"input\": ..., \"target\": ...}),\n", "]\n", "```\n", "\n", "Fortunately, you don't need to implement the `optimization_cycle` yourself in most cases — PINA already provides default implementations tailored to common solver types. These implementations are available through the solver interfaces and cover various training strategies.\n", "\n", "1. [`PINNInterface`](https://mathlab.github.io/PINA/_rst/solver/physics_informed_solver/pinn_interface.html)  \n", "   Implements the optimization cycle for **physics-based solvers** (e.g., PDE residual minimization), as well as other useful methods to compute PDE residuals.  \n", "   ➤ [View method](https://mathlab.github.io/PINA/_rst/solver/physics_informed_solver/pinn_interface.html#pina.solver.physics_informed_solver.pinn_interface.PINNInterface.optimization_cycle)\n", "\n", "2. [`SupervisedSolverInterface`](https://mathlab.github.io/PINA/_rst/solver/supervised_solver/supervised_solver_interface.html)  \n", "   Defines the optimization cycle for **supervised learning tasks**, including traditional regression and classification.
\n", " ➤ [View method](https://mathlab.github.io/PINA/_rst/solver/supervised_solver/supervised_solver_interface.html#pina.solver.supervised_solver.supervised_solver_interface.SupervisedSolverInterface.optimization_cycle)\n", "\n", "3. [`DeepEnsembleSolverInterface`](https://mathlab.github.io/PINA/_rst/solver/ensemble_solver/ensemble_solver_interface.html) \n", " Provides the optimization logic for **deep ensemble methods**, commonly used for uncertainty quantification or robustness. \n", " ➤ [View method](https://mathlab.github.io/PINA/_rst/solver/ensemble_solver/ensemble_solver_interface.html#pina.solver.ensemble_solver.ensemble_solver_interface.DeepEnsembleSolverInterface.optimization_cycle)\n", "\n", "These ready-to-use implementations ensure that your solvers are properly structured and compatible with PINA’s training workflow. You can also inherit and override them to fit more specialized needs. They only require, the following arguments:\n", "**Arguments:**\n", "- `problem` – The problem to be solved.\n", "- `loss` - The loss to be minimized\n", "- `weightings` – Optional loss weighting schema.\n", "- `use_lt` – Whether to use LabelTensors as input.\n", "\n", "## Structure a Solver with Multiple Inheritance:\n", "\n", "Thanks to PINA’s modular design, creating a custom solver is straightforward using **multiple inheritance**. You can combine different interfaces to define both the **optimization logic** and the **model structure**.\n", "\n", "- **`PINN` Solver**\n", " - Inherits from: \n", " - [`PINNInterface`](https://mathlab.github.io/PINA/_rst/solver/physics_informed_solver/pinn_interface.html) → physics-based optimization loop \n", " - [`SingleSolverInterface`](https://mathlab.github.io/PINA/_rst/solver/solver_interface.html) → training a single model\n", "\n", "- **`SupervisedSolver`**\n", " - Inherits from: \n", " - [`SupervisedSolverInterface`](https://mathlab.github.io/PINA/_rst/solver/supervised_solver/supervised_solver_interface.html) → data-driven optimization loop \n", " - [`SingleSolverInterface`](https://mathlab.github.io/PINA/_rst/solver/solver_interface.html) → training a single model\n", "\n", "- **`GAROM`** (a variant of GAN)\n", " - Inherits from: \n", " - [`SupervisedSolverInterface`](https://mathlab.github.io/PINA/_rst/solver/supervised_solver/supervised_solver_interface.html) → data-driven optimization loop \n", " - [`MultiSolverInterface`](https://mathlab.github.io/PINA/_rst/solver/multi_solver_interface.html) → training multiple models (e.g., generator and discriminator)\n", "\n", "This structure promotes **code reuse** and **extensibility**, allowing you to quickly prototype new solver strategies by reusing core training and optimization logic.\n", "\n", "## Let's try to build some solvers!\n", "\n", "We will now start building a simple supervised solver in PINA. Let's first import useful modules! 
" ] }, { "cell_type": "code", "execution_count": null, "id": "0981f1e9", "metadata": {}, "outputs": [], "source": [ "## routine needed to run the notebook on Google Colab\n", "try:\n", " import google.colab\n", " IN_COLAB = True\n", "except:\n", " IN_COLAB = False\n", "if IN_COLAB:\n", " !pip install \"pina-mathlab[tutorial]\"\n", "\n", "import warnings\n", "import torch\n", "import matplotlib.pyplot as plt\n", "\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "from pina import Trainer\n", "from pina.solver import SingleSolverInterface, SupervisedSolverInterface\n", "from pina.model import FeedForward\n", "from pina.problem.zoo import SupervisedProblem" ] }, { "cell_type": "markdown", "id": "7b91de38", "metadata": {}, "source": [ "Since we are using only one model for this task, we will inherit from two base classes:\n", "\n", "- `SingleSolverInterface`: This ensures we are working with a single model.\n", "- `SupervisedSolverInterface`: This allows us to use supervised learning strategies for training the model." ] }, { "cell_type": "code", "execution_count": 50, "id": "014bbd86", "metadata": {}, "outputs": [], "source": [ "class MyFirstSolver(SupervisedSolverInterface, SingleSolverInterface):\n", " def __init__(\n", " self,\n", " problem,\n", " model,\n", " loss=None,\n", " optimizer=None,\n", " scheduler=None,\n", " weighting=None,\n", " use_lt=True,\n", " ):\n", " super().__init__(\n", " model=model,\n", " problem=problem,\n", " loss=loss,\n", " optimizer=optimizer,\n", " scheduler=scheduler,\n", " weighting=weighting,\n", " use_lt=use_lt,\n", " )" ] }, { "cell_type": "markdown", "id": "b1b1e4c4", "metadata": {}, "source": [ "By default, Python follows a specific method resolution order (MRO) when a class inherits from multiple parent classes. This means that the initialization (`__init__`) method is called based on the order of inheritance.\n", "\n", "Since we inherit from `SupervisedSolverInterface` first, Python will call the `__init__` method from `SupervisedSolverInterface` (initialize `problem`, `loss`, `weighting` and `use_lt`) before calling the `__init__` method from `SingleSolverInterface` (initialize `model`, `optimizer`, `scheduler`). This allows us to customize the initialization process for our custom solver. \n", "\n", "We will learn a very simple problem, try to learn $y=\\sin(x)$." ] }, { "cell_type": "code", "execution_count": 51, "id": "6f25d3a6", "metadata": {}, "outputs": [], "source": [ "# get the data\n", "x = torch.linspace(0, torch.pi, 100).view(-1, 1)\n", "y = torch.sin(x)\n", "# build the problem\n", "problem = SupervisedProblem(x, y)\n", "# build the model\n", "model = FeedForward(1, 1)" ] }, { "cell_type": "markdown", "id": "9f7551bf", "metadata": {}, "source": [ "If we now try to initialize the solver `MyFirstSolver` we will get the following error:\n", "\n", "```python\n", "---------------------------------------------------------------------------\n", "TypeError Traceback (most recent call last)\n", "Cell In[41], line 1\n", "----> 1 MyFirstSolver(problem, model)\n", "\n", "TypeError: Can't instantiate abstract class MyFirstSolver with abstract method loss_data\n", "```\n", "\n", "### Data and Physics Loss\n", "The error above is because in PINA, all solvers must specify how to compute the loss during training. There are two main types of losses that can be computed, depending on the nature of the problem:\n", "\n", "1. **`loss_data`**: Computes the **data loss** between the model's output and the true solution. 
This is typically used in **supervised learning** setups, where we have ground-truth data to compare the model's predictions against. It expects some `input` (tensor, graph, ...) and a `target` (tensor, graph, ...).\n", "\n", "2. **`loss_phys`**: Computes the **physics loss** for **physics-informed solvers** (PINNs). This loss is based on the residuals of the governing equations that model physical systems, enforcing the equations during training. It expects some `samples` (`LabelTensor`) and an `equation` (`Equation`).\n", "\n", "Therefore, our implementation becomes:" ] }, { "cell_type": "code", "execution_count": 52, "id": "336e8060", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True (mps), used: False\n", "TPU available: False, using: 0 TPU cores\n", "HPU available: False, using: 0 HPUs\n", "\n", " | Name | Type | Params | Mode \n", "----------------------------------------------------\n", "0 | _pina_models | ModuleList | 481 | train\n", "1 | _loss_fn | MSELoss | 0 | train\n", "----------------------------------------------------\n", "481 Trainable params\n", "0 Non-trainable params\n", "481 Total params\n", "0.002 Total estimated model params size (MB)\n", "9 Modules in train mode\n", "0 Modules in eval mode\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d6d009cd7efb4c76ba2115f828e46dc8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Training: | | 0/? [00:00