update tutorial3

This commit is contained in:
Dario Coscia
2025-03-07 18:26:18 +01:00
committed by Nicola Demo
parent dc71d328cf
commit 7ef39f1e3b
2 changed files with 422 additions and 305 deletions

View File

@@ -32,16 +32,19 @@
" \n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"plt.style.use('tableau-colorblind10')\n",
"import warnings\n",
"\n",
"from pina import Condition, LabelTensor\n",
"from pina.problem import SpatialProblem, TimeDependentProblem\n",
"from pina.operator import laplacian, grad\n",
"from pina.domain import CartesianDomain\n",
"from pina.solver import PINN\n",
"from pina.trainer import Trainer\n",
"from pina.equation import Equation\n",
"from pina.equation.equation_factory import FixedValue\n",
"from pina import Condition, LabelTensor"
"from pina.equation import Equation, FixedValue\n",
"\n",
"from lightning.pytorch.loggers import TensorBoardLogger\n",
"\n",
"warnings.filterwarnings('ignore')"
]
},
{
@@ -86,37 +89,61 @@
"outputs": [],
"source": [
"class Wave(TimeDependentProblem, SpatialProblem):\n",
" output_variables = ['u']\n",
" spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]})\n",
" temporal_domain = CartesianDomain({'t': [0, 1]})\n",
" output_variables = [\"u\"]\n",
" spatial_domain = CartesianDomain({\"x\": [0, 1], \"y\": [0, 1]})\n",
" temporal_domain = CartesianDomain({\"t\": [0, 1]})\n",
"\n",
" def wave_equation(input_, output_):\n",
" u_t = grad(output_, input_, components=['u'], d=['t'])\n",
" u_tt = grad(u_t, input_, components=['dudt'], d=['t'])\n",
" nabla_u = laplacian(output_, input_, components=['u'], d=['x', 'y'])\n",
" u_t = grad(output_, input_, components=[\"u\"], d=[\"t\"])\n",
" u_tt = grad(u_t, input_, components=[\"dudt\"], d=[\"t\"])\n",
" nabla_u = laplacian(output_, input_, components=[\"u\"], d=[\"x\", \"y\"])\n",
" return nabla_u - u_tt\n",
"\n",
" def initial_condition(input_, output_):\n",
" u_expected = (torch.sin(torch.pi*input_.extract(['x'])) *\n",
" torch.sin(torch.pi*input_.extract(['y'])))\n",
" return output_.extract(['u']) - u_expected\n",
" u_expected = torch.sin(torch.pi * input_.extract([\"x\"])) * torch.sin(\n",
" torch.pi * input_.extract([\"y\"])\n",
" )\n",
" return output_.extract([\"u\"]) - u_expected\n",
"\n",
" conditions = {\n",
" 'bound_cond1': Condition(domain=CartesianDomain({'x': [0, 1], 'y': 1, 't': [0, 1]}), equation=FixedValue(0.)),\n",
" 'bound_cond2': Condition(domain=CartesianDomain({'x': [0, 1], 'y': 0, 't': [0, 1]}), equation=FixedValue(0.)),\n",
" 'bound_cond3': Condition(domain=CartesianDomain({'x': 1, 'y': [0, 1], 't': [0, 1]}), equation=FixedValue(0.)),\n",
" 'bound_cond4': Condition(domain=CartesianDomain({'x': 0, 'y': [0, 1], 't': [0, 1]}), equation=FixedValue(0.)),\n",
" 'time_cond': Condition(domain=CartesianDomain({'x': [0, 1], 'y': [0, 1], 't': 0}), equation=Equation(initial_condition)),\n",
" 'phys_cond': Condition(domain=CartesianDomain({'x': [0, 1], 'y': [0, 1], 't': [0, 1]}), equation=Equation(wave_equation)),\n",
" \"bound_cond1\": Condition(\n",
" domain=CartesianDomain({\"x\": [0, 1], \"y\": 1, \"t\": [0, 1]}),\n",
" equation=FixedValue(0.0),\n",
" ),\n",
" \"bound_cond2\": Condition(\n",
" domain=CartesianDomain({\"x\": [0, 1], \"y\": 0, \"t\": [0, 1]}),\n",
" equation=FixedValue(0.0),\n",
" ),\n",
" \"bound_cond3\": Condition(\n",
" domain=CartesianDomain({\"x\": 1, \"y\": [0, 1], \"t\": [0, 1]}),\n",
" equation=FixedValue(0.0),\n",
" ),\n",
" \"bound_cond4\": Condition(\n",
" domain=CartesianDomain({\"x\": 0, \"y\": [0, 1], \"t\": [0, 1]}),\n",
" equation=FixedValue(0.0),\n",
" ),\n",
" \"time_cond\": Condition(\n",
" domain=CartesianDomain({\"x\": [0, 1], \"y\": [0, 1], \"t\": 0}),\n",
" equation=Equation(initial_condition),\n",
" ),\n",
" \"phys_cond\": Condition(\n",
" domain=CartesianDomain({\"x\": [0, 1], \"y\": [0, 1], \"t\": [0, 1]}),\n",
" equation=Equation(wave_equation),\n",
" ),\n",
" }\n",
"\n",
" def wave_sol(self, pts):\n",
" return (torch.sin(torch.pi*pts.extract(['x'])) *\n",
" torch.sin(torch.pi*pts.extract(['y'])) *\n",
" torch.cos(torch.sqrt(torch.tensor(2.))*torch.pi*pts.extract(['t'])))\n",
" def truth_solution(self, pts):\n",
" f = (\n",
" torch.sin(torch.pi * pts.extract([\"x\"]))\n",
" * torch.sin(torch.pi * pts.extract([\"y\"]))\n",
" * torch.cos(\n",
" torch.sqrt(torch.tensor(2.0)) * torch.pi * pts.extract([\"t\"])\n",
" )\n",
" )\n",
" return LabelTensor(f, self.output_variables)\n",
"\n",
" truth_solution = wave_sol\n",
"\n",
"# define problem\n",
"problem = Wave()"
]
},
@@ -152,16 +179,23 @@
" def __init__(self, input_dim, output_dim):\n",
" super().__init__()\n",
"\n",
" self.layers = torch.nn.Sequential(torch.nn.Linear(input_dim, 40),\n",
" torch.nn.ReLU(),\n",
" torch.nn.Linear(40, 40),\n",
" torch.nn.ReLU(),\n",
" torch.nn.Linear(40, output_dim))\n",
" \n",
" self.layers = torch.nn.Sequential(\n",
" torch.nn.Linear(input_dim, 40),\n",
" torch.nn.ReLU(),\n",
" torch.nn.Linear(40, 40),\n",
" torch.nn.ReLU(),\n",
" torch.nn.Linear(40, output_dim),\n",
" )\n",
"\n",
" # here in the foward we implement the hard constraints\n",
" def forward(self, x):\n",
" hard = x.extract(['x'])*(1-x.extract(['x']))*x.extract(['y'])*(1-x.extract(['y']))\n",
" return hard*self.layers(x)"
" hard = (\n",
" x.extract([\"x\"])\n",
" * (1 - x.extract([\"x\"]))\n",
" * x.extract([\"y\"])\n",
" * (1 - x.extract([\"y\"]))\n",
" )\n",
" return hard * self.layers(x)"
]
},
{
@@ -177,7 +211,7 @@
"id": "b465bebd",
"metadata": {},
"source": [
"In this tutorial, the neural network is trained for 1000 epochs with a learning rate of 0.001 (default in `PINN`). Training takes approximately 3 minutes."
"In this tutorial, the neural network is trained for 1000 epochs with a learning rate of 0.001 (default in `PINN`). As always, we will log using `Tensorboard`."
]
},
{
@@ -188,22 +222,66 @@
"outputs": [],
"source": [
"# generate the data\n",
"problem.discretise_domain(1000, 'random', domains=['phys_cond', 'time_cond', 'bound_cond1', 'bound_cond2', 'bound_cond3', 'bound_cond4'])\n",
"problem.discretise_domain(\n",
" 1000,\n",
" \"random\",\n",
" domains=[\n",
" \"phys_cond\",\n",
" \"time_cond\",\n",
" \"bound_cond1\",\n",
" \"bound_cond2\",\n",
" \"bound_cond3\",\n",
" \"bound_cond4\",\n",
" ],\n",
")\n",
"\n",
"# create the solver\n",
"pinn = PINN(problem, HardMLP(len(problem.input_variables), len(problem.output_variables)))\n",
"# define model\n",
"model = HardMLP(len(problem.input_variables), len(problem.output_variables))\n",
"\n",
"# crete the solver\n",
"pinn = PINN(problem=problem, model=model)\n",
"\n",
"# create trainer and train\n",
"trainer = Trainer(pinn, max_epochs=1000, accelerator='cpu', enable_model_summary=False) # we train on CPU and avoid model summary at beginning of training (optional)\n",
"trainer = Trainer(\n",
" solver=pinn,\n",
" max_epochs=1000,\n",
" accelerator=\"cpu\",\n",
" enable_model_summary=False,\n",
" train_size=1.0,\n",
" val_size=0.0,\n",
" test_size=0.0,\n",
" logger=TensorBoardLogger(\"tutorial_logs\"),\n",
" enable_progress_bar=False,\n",
")\n",
"trainer.train()"
]
},
{
"cell_type": "markdown",
"id": "4c6dbfac",
"metadata": {},
"source": [
"Let's now plot the logging to see how the losses vary during training. For this, we will use `TensorBoard`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "77bfcb6e",
"metadata": {},
"outputs": [],
"source": [
"# Load the TensorBoard extension\n",
"%load_ext tensorboard\n",
"%tensorboard --logdir 'tutorial_logs'\n"
]
},
{
"cell_type": "markdown",
"id": "c2a5c405",
"metadata": {},
"source": [
"Notice that the loss on the boundaries of the spatial domain is exactly zero, as expected! After the training is completed one can now plot some results using the `matplotlib`. We plot the predicted output on the left side, the true solution at the center and the difference on the right side."
"Notice that the loss on the boundaries of the spatial domain is exactly zero, as expected! After the training is completed one can now plot some results using the `matplotlib`. We plot the predicted output on the left side, the true solution at the center and the difference on the right side using the `plot_solution` function."
]
},
{
@@ -213,32 +291,35 @@
"metadata": {},
"outputs": [],
"source": [
"def fixed_time_plot(fixed_variables, pinn):\n",
" #sample domain points and get values corresponding to fixed variables\n",
" pts = pinn.problem.spatial_domain.sample(256, 'grid', variables=['x','y'])\n",
" grids = [p_.reshape(256, 256) for p_ in pts.extract(['x','y']).T]\n",
" fixed_pts = torch.ones(pts.shape[0], len(fixed_variables))\n",
" fixed_pts *= torch.tensor(list(fixed_variables.values()))\n",
" fixed_pts = fixed_pts.as_subclass(LabelTensor)\n",
" fixed_pts.labels = list(fixed_variables.keys())\n",
" pts = pts.append(fixed_pts).to(device=pinn.device)\n",
" predicted_output = pinn.forward(pts).extract('u').as_subclass(torch.Tensor).cpu().detach().reshape(256,256)\n",
" #get true solution\n",
" true_output = pinn.problem.truth_solution(pts).cpu().detach().reshape(256,256)\n",
" pts = pts.cpu()\n",
" #plot prediction, true solution and difference\n",
" grids = [p_.reshape(256, 256) for p_ in pts.extract(['x','y']).T]\n",
" fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))\n",
" cb = getattr(ax[0], 'contourf')(*grids, predicted_output)\n",
" fig.colorbar(cb, ax=ax[0])\n",
" ax[0].title.set_text('Neural Network prediction')\n",
" cb = getattr(ax[1], 'contourf')(*grids, true_output)\n",
" fig.colorbar(cb, ax=ax[1])\n",
" ax[1].title.set_text('True solution')\n",
" cb = getattr(ax[2],'contourf')(*grids,(true_output - predicted_output))\n",
" fig.colorbar(cb, ax=ax[2])\n",
" ax[2].title.set_text('Residual')\n",
" plt.show(block=True)"
"@torch.no_grad()\n",
"def plot_solution(solver, time):\n",
" # get the problem\n",
" problem = solver.problem\n",
" # get spatial points\n",
" spatial_samples = problem.spatial_domain.sample(30, \"grid\")\n",
" # get temporal value\n",
" time = LabelTensor(torch.tensor([[time]]), \"t\")\n",
" # cross data\n",
" points = spatial_samples.append(time, mode=\"cross\")\n",
" # compute pinn solution, true solution and absolute difference\n",
" data = {\n",
" \"PINN solution\": solver(points),\n",
" \"True solution\": problem.truth_solution(points),\n",
" \"Absolute Difference\": torch.abs(\n",
" solver(points) - problem.truth_solution(points)\n",
" )\n",
" }\n",
" # plot the solution\n",
" plt.suptitle(f'Solution for time {time.item()}')\n",
" for idx, (title, field) in enumerate(data.items()):\n",
" plt.subplot(1, 3, idx + 1)\n",
" plt.title(title)\n",
" plt.tricontourf( # convert to torch tensor + flatten\n",
" points.extract(\"x\").tensor.flatten(),\n",
" points.extract(\"y\").tensor.flatten(),\n",
" field.tensor.flatten(),\n",
" )\n",
" plt.colorbar(), plt.tight_layout()"
]
},
{
@@ -256,12 +337,14 @@
"metadata": {},
"outputs": [],
"source": [
"print('Plotting at t=0')\n",
"fixed_time_plot(fixed_variables={'t':0.0},pinn=pinn)\n",
"print('Plotting at t=0.5')\n",
"fixed_time_plot(fixed_variables={'t':0.5},pinn=pinn)\n",
"print('Plotting at t=1.0')\n",
"fixed_time_plot(fixed_variables={'t':1.0},pinn=pinn)"
"plt.figure(figsize=(12, 6))\n",
"plot_solution(solver=pinn, time=0)\n",
"\n",
"plt.figure(figsize=(12, 6))\n",
"plot_solution(solver=pinn, time=0.5)\n",
"\n",
"plt.figure(figsize=(12, 6))\n",
"plot_solution(solver=pinn, time=1)"
]
},
{
@@ -290,17 +373,30 @@
" def __init__(self, input_dim, output_dim):\n",
" super().__init__()\n",
"\n",
" self.layers = torch.nn.Sequential(torch.nn.Linear(input_dim, 40),\n",
" torch.nn.ReLU(),\n",
" torch.nn.Linear(40, 40),\n",
" torch.nn.ReLU(),\n",
" torch.nn.Linear(40, output_dim))\n",
" \n",
" self.layers = torch.nn.Sequential(\n",
" torch.nn.Linear(input_dim, 40),\n",
" torch.nn.ReLU(),\n",
" torch.nn.Linear(40, 40),\n",
" torch.nn.ReLU(),\n",
" torch.nn.Linear(40, output_dim),\n",
" )\n",
"\n",
" # here in the foward we implement the hard constraints\n",
" def forward(self, x):\n",
" hard_space = x.extract(['x'])*(1-x.extract(['x']))*x.extract(['y'])*(1-x.extract(['y']))\n",
" hard_t = torch.sin(torch.pi*x.extract(['x'])) * torch.sin(torch.pi*x.extract(['y'])) * torch.cos(torch.sqrt(torch.tensor(2.))*torch.pi*x.extract(['t']))\n",
" return hard_space * self.layers(x) * x.extract(['t']) + hard_t"
" hard_space = (\n",
" x.extract([\"x\"])\n",
" * (1 - x.extract([\"x\"]))\n",
" * x.extract([\"y\"])\n",
" * (1 - x.extract([\"y\"]))\n",
" )\n",
" hard_t = (\n",
" torch.sin(torch.pi * x.extract([\"x\"]))\n",
" * torch.sin(torch.pi * x.extract([\"y\"]))\n",
" * torch.cos(\n",
" torch.sqrt(torch.tensor(2.0)) * torch.pi * x.extract([\"t\"])\n",
" )\n",
" )\n",
" return hard_space * self.layers(x) * x.extract([\"t\"]) + hard_t"
]
},
{
@@ -308,7 +404,7 @@
"id": "5d3dc67b",
"metadata": {},
"source": [
"Now let's train with the same configuration as thre previous test"
"Now let's train with the same configuration as the previous test"
]
},
{
@@ -318,14 +414,24 @@
"metadata": {},
"outputs": [],
"source": [
"# generate the data\n",
"problem.discretise_domain(1000, 'random', domains=['phys_cond', 'time_cond', 'bound_cond1', 'bound_cond2', 'bound_cond3', 'bound_cond4'])\n",
"# define model\n",
"model = HardMLPtime(len(problem.input_variables), len(problem.output_variables))\n",
"\n",
"# crete the solver\n",
"pinn = PINN(problem, HardMLPtime(len(problem.input_variables), len(problem.output_variables)))\n",
"pinn = PINN(problem=problem, model=model)\n",
"\n",
"# create trainer and train\n",
"trainer = Trainer(pinn, max_epochs=1000, accelerator='cpu', enable_model_summary=False) # we train on CPU and avoid model summary at beginning of training (optional)\n",
"trainer = Trainer(\n",
" solver=pinn,\n",
" max_epochs=1000,\n",
" accelerator=\"cpu\",\n",
" enable_model_summary=False,\n",
" train_size=1.0,\n",
" val_size=0.0,\n",
" test_size=0.0,\n",
" logger=TensorBoardLogger(\"tutorial_logs\"),\n",
" enable_progress_bar=False,\n",
")\n",
"trainer.train()"
]
},
@@ -344,12 +450,14 @@
"metadata": {},
"outputs": [],
"source": [
"print('Plotting at t=0')\n",
"fixed_time_plot(fixed_variables={'t':0.0},pinn=pinn)\n",
"print('Plotting at t=0.5')\n",
"fixed_time_plot(fixed_variables={'t':0.5},pinn=pinn)\n",
"print('Plotting at t=1.0')\n",
"fixed_time_plot(fixed_variables={'t':1.0},pinn=pinn)\n"
"plt.figure(figsize=(12, 6))\n",
"plot_solution(solver=pinn, time=0)\n",
"\n",
"plt.figure(figsize=(12, 6))\n",
"plot_solution(solver=pinn, time=0.5)\n",
"\n",
"plt.figure(figsize=(12, 6))\n",
"plot_solution(solver=pinn, time=1)"
]
},
{
@@ -357,7 +465,17 @@
"id": "b7338109",
"metadata": {},
"source": [
"We can see now that the results are way better! This is due to the fact that previously the network was not learning correctly the initial conditon, leading to a poor solution when time evolved. By imposing the initial condition the network is able to correctly solve the problem."
"We can see now that the results are way better! This is due to the fact that previously the network was not learning correctly the initial conditon, leading to a poor solution when time evolved. By imposing the initial condition the network is able to correctly solve the problem. We can also see using Tensorboard how the two losses decreased:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7ce34dac",
"metadata": {},
"outputs": [],
"source": [
"%tensorboard --logdir 'tutorial_logs'"
]
},
{
@@ -381,7 +499,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "pina",
"language": "python",
"name": "python3"
},
@@ -395,7 +513,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
"version": "3.9.21"
}
},
"nbformat": 4,