{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Tutorial: Chemical Properties Prediction with Graph Neural Networks\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mathLab/PINA/blob/master/tutorials/tutorial15/tutorial.ipynb)\n", "\n", "In this tutorial we will use **Graph Neural Networks** (GNNs) for chemical properties prediction. Chemical properties prediction involves estimating or determining the physical, chemical, or biological characteristics of molecules based on their structure. \n", "\n", "Molecules can naturally be represented as graphs, where atoms serve as the nodes and chemical bonds as the edges connecting them. This graph-based structure makes GNNs a great fit for predicting chemical properties.\n", "\n", "In the tutorial we will use the [QM9 dataset](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.datasets.QM9.html#torch_geometric.datasets.QM9) from Pytorch Geometric. The dataset contains small molecules, each consisting of up to 29 atoms, with every atom having a corresponding 3D position. Each atom is also represented by a five-dimensional one-hot encoded vector that indicates the atom type (H, C, N, O, F).\n", "\n", "First of all, let's start by importing useful modules!" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "## routine needed to run the notebook on Google Colab\n", "try:\n", " import google.colab\n", "\n", " IN_COLAB = True\n", "except:\n", " IN_COLAB = False\n", "if IN_COLAB:\n", " !pip install \"pina-mathlab[tutorial]\"\n", "\n", "import torch\n", "import warnings\n", "\n", "from pina import Trainer\n", "from pina.solver import SupervisedSolver\n", "from pina.problem.zoo import SupervisedProblem\n", "\n", "from torch_geometric.datasets import QM9\n", "from torch_geometric.nn import GCNConv, global_mean_pool\n", "\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Download Data and create the Problem" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We download the dataset and save the molecules as a list of `Data` objects (`input_`), where each element contains one molecule encoded in a graph structure. 
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Download Data and create the Problem" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We download the dataset and save the molecules as a list of `Data` objects (`input_`), where each element contains one molecule encoded as a graph. The corresponding target properties (`target_`) are listed below:\n", "\n", "| Target | Property | Description | Unit |\n", "|--------|----------|-------------|------|\n", "| 0 | $\\mu$ | Dipole moment | $D$ |\n", "| 1 | $\\alpha$ | Isotropic polarizability | $a_0^3$ |\n", "| 2 | $\\epsilon_{\\textrm{HOMO}}$ | Highest occupied molecular orbital energy | $eV$ |\n", "| 3 | $\\epsilon_{\\textrm{LUMO}}$ | Lowest unoccupied molecular orbital energy | $eV$ |\n", "| 4 | $\\Delta \\epsilon$ | Gap between $\\epsilon_{\\textrm{HOMO}}$ and $\\epsilon_{\\textrm{LUMO}}$ | $eV$ |\n", "| 5 | $\\langle R^2 \\rangle$ | Electronic spatial extent | $a_0^2$ |\n", "| 6 | $\\textrm{ZPVE}$ | Zero point vibrational energy | $eV$ |\n", "| 7 | $U_0$ | Internal energy at 0 K | $eV$ |\n", "| 8 | $U$ | Internal energy at 298.15 K | $eV$ |\n", "| 9 | $H$ | Enthalpy at 298.15 K | $eV$ |\n", "| 10 | $G$ | Free energy at 298.15 K | $eV$ |\n", "| 11 | $c_{\\textrm{v}}$ | Heat capacity at 298.15 K | $\\textrm{cal/(mol·K)}$ |\n", "| 12 | $U_0^{\\textrm{ATOM}}$ | Atomization energy at 0 K | $eV$ |\n", "| 13 | $U^{\\textrm{ATOM}}$ | Atomization energy at 298.15 K | $eV$ |\n", "| 14 | $H^{\\textrm{ATOM}}$ | Atomization enthalpy at 298.15 K | $eV$ |\n", "| 15 | $G^{\\textrm{ATOM}}$ | Atomization free energy at 298.15 K | $eV$ |\n", "| 16 | $A$ | Rotational constant | $GHz$ |\n", "| 17 | $B$ | Rotational constant | $GHz$ |\n", "| 18 | $C$ | Rotational constant | $GHz$ |\n" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# download the dataset and shuffle it\n", "dataset = QM9(root=\"./tutorial_logs\").shuffle()\n", "\n", "# store the graphs and stack the 19 target properties\n", "input_ = [data for data in dataset]\n", "target_ = torch.cat([data.y for data in dataset])\n", "\n", "# normalize each target property to zero mean and unit variance\n", "mean = target_.mean(dim=0, keepdim=True)\n", "std = target_.std(dim=0, keepdim=True)\n", "target_ = (target_ - mean) / std" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Great! Once the data are downloaded, building the problem is straightforward with the [`SupervisedProblem`](https://mathlab.github.io/PINA/_rst/problem/zoo/supervised_problem.html) class." ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# build the problem\n", "problem = SupervisedProblem(input_=input_, output_=target_)" ] },
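{ "cell_type": "markdown", "metadata": {}, "source": [ "As a quick, optional sanity check we can inspect a single molecule to confirm the structure described above (node features, bond connectivity, and 3D atom positions). In the PyG QM9 graphs, the first five columns of `data.x` hold the one-hot atom type." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# peek at one molecule of the dataset\n", "sample = input_[0]\n", "print(sample)\n", "print(\"node features:\", sample.x.shape)  # [n_atoms, 11], first 5 columns are the one-hot atom type\n", "print(\"bonds (directed edges):\", sample.edge_index.shape)  # [2, n_edges]\n", "print(\"3D positions:\", sample.pos.shape)  # [n_atoms, 3]\n", "print(\"targets:\", sample.y.shape)  # [1, 19]" ] },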
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Build the Model\n", "\n", "To predict molecular properties, we will construct a simple Graph Convolutional Network using the [`GCNConv`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GCNConv.html) module from PyG. While this tutorial focuses on a straightforward model, more advanced architectures, such as equivariant networks, could yield better performance; this model serves demonstration purposes only.\n", "\n", "**Importantly**, notice that the `forward` pass takes a `Data` object as input and unpacks the graph attributes inside it. This is the only requirement in **PINA** to use graphs and solvers together." ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "class GNN(torch.nn.Module):\n", "    def __init__(self, in_features, out_features, hidden_dim=256):\n", "        super().__init__()\n", "        self.conv1 = GCNConv(in_features, hidden_dim)\n", "        self.conv2 = GCNConv(hidden_dim, hidden_dim)\n", "        self.fc = torch.nn.Linear(hidden_dim, out_features)\n", "\n", "    def forward(self, data):\n", "        # extract the graph attributes; N.B. in PINA, Data objects are passed as input\n", "        x, edge_index, batch = data.x, data.edge_index, data.batch\n", "        # two rounds of graph convolution with ReLU activations\n", "        x = torch.relu(self.conv1(x, edge_index))\n", "        x = torch.relu(self.conv2(x, edge_index))\n", "        # aggregate node features into one vector per molecule\n", "        x = global_mean_pool(x, batch)\n", "        return self.fc(x)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Train the Model\n", "\n", "Now that the problem is created and the model is built, we can train the model using the [`SupervisedSolver`](https://mathlab.github.io/PINA/_rst/solver/supervised.html), which is the solver for standard supervised learning tasks. We will optimize the Mean Absolute Error (MAE) and test on the same metric. In the [`Trainer`](https://mathlab.github.io/PINA/_rst/trainer.html) class we specify the optimization hyperparameters; a minimal setup is sketched in the next cell, where the epoch count, batch size, and data splits are illustrative choices rather than tuned values." ] },
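{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# build the model: 11 input node features, 19 target properties\n", "model = GNN(in_features=dataset.num_features, out_features=19)\n", "\n", "# supervised solver minimizing the MAE (L1) loss;\n", "# use_lt=False since inputs are graphs, not PINA LabelTensors\n", "solver = SupervisedSolver(\n", "    problem=problem, model=model, loss=torch.nn.L1Loss(), use_lt=False\n", ")\n", "\n", "# trainer holding the optimization hyperparameters (illustrative values)\n", "trainer = Trainer(\n", "    solver,\n", "    max_epochs=10,\n", "    batch_size=128,\n", "    accelerator=\"cpu\",\n", "    train_size=0.8,\n", "    val_size=0.1,\n", "    test_size=0.1,\n", ")\n", "trainer.train()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluate the Model\n", "\n", "Once training is done, we compare the predictions with the true properties. The snippet below bypasses the PINA trainer and evaluates the network directly with a PyG `DataLoader` on a slice of the dataset (an illustrative choice; this slice may overlap the trainer's internal training split). It defines the names `target_test`, `prediction_test`, and `properties` used by the plotting cell that follows." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from torch_geometric.loader import DataLoader\n", "\n", "# molecules used for the comparison below\n", "test_molecules = input_[:1000]\n", "loader = DataLoader(test_molecules, batch_size=256)\n", "\n", "# run the trained model in inference mode\n", "model.eval()\n", "with torch.no_grad():\n", "    prediction_test = torch.cat([model(batch) for batch in loader])\n", "\n", "# normalize the true targets with the same statistics used for training\n", "target_test = torch.cat([mol.y for mol in test_molecules])\n", "target_test = (target_test - mean) / std\n", "\n", "# property names, in the order of the table above\n", "properties = [\n", "    \"mu\", \"alpha\", \"eps_HOMO\", \"eps_LUMO\", \"Delta eps\", \"<R^2>\", \"ZPVE\",\n", "    \"U_0\", \"U\", \"H\", \"G\", \"c_v\", \"U_0 ATOM\", \"U ATOM\", \"H ATOM\",\n", "    \"G ATOM\", \"A\", \"B\", \"C\",\n", "]" ] },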
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "# Set up the plot grid\n", "num_properties = 19\n", "fig, axes = plt.subplots(4, 5, figsize=(10, 8))\n", "axes = axes.flatten()\n", "\n", "# Outlier removal using IQR (with torch)\n", "for idx in range(num_properties):\n", "    target_vals = target_test[:, idx]\n", "    pred_vals = prediction_test[:, idx]\n", "\n", "    # Calculate Q1 (25th percentile) and Q3 (75th percentile) using torch\n", "    Q1 = torch.quantile(target_vals, 0.25)\n", "    Q3 = torch.quantile(target_vals, 0.75)\n", "    IQR = Q3 - Q1\n", "\n", "    # Define the outlier range\n", "    lower_bound = Q1 - 1.5 * IQR\n", "    upper_bound = Q3 + 1.5 * IQR\n", "\n", "    # Filter out the outliers\n", "    mask = (target_vals >= lower_bound) & (target_vals <= upper_bound)\n", "    filtered_target = target_vals[mask]\n", "    filtered_pred = pred_vals[mask]\n", "\n", "    # Plot predictions against targets, plus the ideal y=x line\n", "    ax = axes[idx]\n", "    ax.scatter(\n", "        filtered_target.detach(),\n", "        filtered_pred.detach(),\n", "        alpha=0.5,\n", "        label=\"Data points (no outliers)\",\n", "    )\n", "    ax.plot(\n", "        [filtered_target.min().item(), filtered_target.max().item()],\n", "        [filtered_target.min().item(), filtered_target.max().item()],\n", "        \"r--\",\n", "        label=\"y=x\",\n", "    )\n", "\n", "    ax.set_title(properties[idx])\n", "    ax.set_xlabel(\"Target\")\n", "    ax.set_ylabel(\"Prediction\")\n", "\n", "# Remove the extra subplot (there are 19 properties on a 4x5 = 20-slot grid)\n", "if num_properties < len(axes):\n", "    fig.delaxes(axes[-1])\n", "\n", "plt.tight_layout()\n", "plt.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Looking in more detail, we can see that the rotational constant $A$ is not predicted particularly well, although the small magnitude of this quantity leads to a lower MAE than for the other properties. From the plot we can also see that the atomization energies, free energies, and enthalpies are the predicted properties that correlate best with the true chemical properties." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## What's Next?\n", "\n", "Congratulations on completing the tutorial on chemical properties prediction with **PINA**! Now that you've got the basics, there are several exciting directions to explore:\n", "\n", "1. **Train the network for longer or with different layer sizes**: Experiment with various configurations to see how the network's accuracy improves; a configurable sketch is given after this list.\n", "\n", "2. **Use a different network**: For example, Equivariant Graph Neural Networks (EGNNs) have shown great results on molecular tasks by leveraging group symmetries. If you're interested, check out [*E(n) Equivariant Graph Neural Networks*](https://arxiv.org/abs/2102.09844) for more details.\n", "\n", "3. **What if the input is time-dependent?**: For example, predicting force fields in Molecular Dynamics simulations. In PINA, you can predict force fields with ease, as it's still a supervised learning task. If this interests you, have a look at [*Machine Learning Force Fields*](https://pubs.acs.org/doi/10.1021/acs.chemrev.0c01111).\n", "\n", "4. **...and many more!**: The possibilities are vast, including exploring new architectures, working with larger datasets, and applying this framework to more complex systems." ] },
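{ "cell_type": "markdown", "metadata": {}, "source": [ "As a starting point for the first suggestion, here is a sketch of a configurable-depth variant of the `GNN` class above; `hidden_dim` and `num_layers` are free hyperparameters to experiment with, not recommended values." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# a deeper, configurable variant of the GNN used in this tutorial\n", "class DeepGNN(torch.nn.Module):\n", "    def __init__(self, in_features, out_features, hidden_dim=256, num_layers=4):\n", "        super().__init__()\n", "        # a stack of graph convolutions: input -> hidden, then hidden -> hidden\n", "        self.convs = torch.nn.ModuleList(\n", "            [GCNConv(in_features, hidden_dim)]\n", "            + [GCNConv(hidden_dim, hidden_dim) for _ in range(num_layers - 1)]\n", "        )\n", "        self.fc = torch.nn.Linear(hidden_dim, out_features)\n", "\n", "    def forward(self, data):\n", "        x, edge_index, batch = data.x, data.edge_index, data.batch\n", "        for conv in self.convs:\n", "            x = torch.relu(conv(x, edge_index))\n", "        # pool node features into one vector per molecule, then regress\n", "        return self.fc(global_mean_pool(x, batch))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "For more resources and tutorials, check out the [PINA Documentation](https://mathlab.github.io/PINA/)."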
] } ], "metadata": { "kernelspec": { "display_name": "pina", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 2 }