tutorial validation (#185)

Co-authored-by: Ben Volokh <89551265+benv123@users.noreply.github.com>
Nicola Demo
2023-10-17 10:54:31 +02:00
parent 2e2fe93458
commit 32ff5de1f4
38 changed files with 1072 additions and 1006 deletions


@@ -1,12 +1,11 @@
Tutorial 4: continuous convolutional filter
===========================================
In this tutorial, we will show how to use the Continuous Convolutional
Filter, and how to build common Deep Learning architectures with it. The
implementation of the filter follows the original work `**A Continuous
Convolutional Trainable Filter for Modelling Unstructured
Data** <https://arxiv.org/abs/2210.13416>`__.
First of all, we import the modules needed for the tutorial, which
include:
@@ -20,17 +19,20 @@ include:
import torch
import matplotlib.pyplot as plt
from pina.model.layers import ContinuousConvBlock
import torchvision # for MNIST dataset
from pina.model import FeedForward # for building AE and MNIST classification
The tutorial is structured as follows:

* `Continuous filter background <#continuous-filter-background>`__: understand how the convolutional filter works and how to use it.
* `Building a MNIST Classifier <#building-a-mnist-classifier>`__: show how to build a simple classifier using the MNIST dataset and how to combine a continuous convolutional layer with a feedforward neural network.
* `Building a Continuous Convolutional Autoencoder <#building-a-continuous-convolutional-autoencoder>`__: show how to use the continuous filter to work with unstructured data for autoencoding and up-sampling.
Continuous filter background
----------------------------
@@ -44,7 +46,7 @@ as:
\mathcal{I}_{\rm{out}}(\mathbf{x}) = \int_{\mathcal{X}} \mathcal{I}(\mathbf{x} + \mathbf{\tau}) \cdot \mathcal{K}(\mathbf{\tau}) d\mathbf{\tau},
where :math:`\mathcal{K} : \mathcal{X} \rightarrow \mathbb{R}` is the
*continuous filter* function, and
:math:`\mathcal{I} : \Omega \subset \mathbb{R}^N \rightarrow \mathbb{R}`
is the input function. The continuous filter function is approximated
@@ -60,7 +62,7 @@ by the authors. Thus, given :math:`\{\mathbf{x}_i\}_{i=1}^{n}` points in
\mathcal{I}_{\rm{out}}(\mathbf{\tilde{x}}_i) = \sum_{{\mathbf{x}_i}\in\mathcal{X}} \mathcal{I}(\mathbf{x}_i + \mathbf{\tau}) \cdot \mathcal{K}(\mathbf{x}_i),
where :math:`\mathbf{\tau} \in \mathcal{S}`, with :math:`\mathcal{S}`
the set of available strides, corresponds to the current stride position
of the filter, and :math:`\mathbf{\tilde{x}}_i` points are obtained by
taking the centroid of the filter position mapped on the :math:`\Omega`
@@ -81,15 +83,17 @@ shape:
.. math:: [B \times N_{in} \times N \times D]
where :math:`B` is the batch size, :math:`N_{in}` is the number of
input fields, :math:`N` the number of points in the mesh, and :math:`D` the
dimension of the problem. In particular:

* :math:`D` is the number of spatial variables + 1. The last column must contain the field value. For example, for 2D problems :math:`D=3` and the tensor will be something like ``[first coordinate, second coordinate, field value]``
* :math:`N_{in}` represents the number of vectorial components of the function. For example, a vectorial function :math:`f = [f_1, f_2]` will have :math:`N_{in}=2`
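For instance (a minimal sketch with names and a sample field of our own), a single-field 2D input could be assembled as:

.. code:: python

    import math
    import torch

    # hypothetical mesh: N random points in the unit square
    N = 100
    coords = torch.rand(N, 2)                    # (x, y) coordinates
    field = torch.sin(math.pi * coords[:, 0:1])  # one scalar field value per point

    # assemble the [B, N_in, N, D] tensor: B=1 batch, N_in=1 field, D=2+1 columns
    data = torch.cat([coords, field], dim=-1).unsqueeze(0).unsqueeze(0)
    print(data.shape)  # torch.Size([1, 1, 100, 3])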
Let's see a fuller example to clarify these ideas. We will be verbose to
explain the input format in detail. We wish to create the function:
.. math::
@@ -144,22 +148,20 @@ where to go. Here is an example for the :math:`[0,1]\times[0,5]` domain:
.. code:: python
# stride definition
stride = {"domain": [1, 5],
"start": [0, 0],
"jump": [0.1, 0.3],
"direction": [1, 1],
}
This tells the filter:
1. ``domain``: square domain (the only implemented) :math:`[0,1]\times[0,5]`. The minimum value is always zero, while the maximum is specified by the user
2. ``start``: start position of the filter, coordinate :math:`(0, 0)`
3. ``jump``: the jumps of the centroid of the filter to the next position :math:`(0.1, 0.3)`
4. ``direction``: the directions of the jump, with ``1 = right``, ``0 = no jump``, ``-1 = left`` with respect to the current position
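As a rough illustration, the centroid positions generated by this stride can be enumerated by hand. This is a sketch under our own assumption of an endpoint-inclusive sweep; check the implementation for the exact boundary handling.

.. code:: python

    import torch

    # centroids visited on the [0,1]x[0,5] domain with jumps (0.1, 0.3)
    xs = torch.arange(0.0, 1.0 + 1e-8, 0.1)  # first axis: start 0, jump 0.1
    ys = torch.arange(0.0, 5.0 + 1e-8, 0.3)  # second axis: start 0, jump 0.3
    centroids = torch.cartesian_prod(xs, ys)
    print(centroids.shape[0], "filter positions")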
**Note**
@@ -188,30 +190,37 @@ fix the filter dimension to be :math:`[0.1, 0.1]`.
}
# creating the filter
cConv = ContinuousConvBlock(input_numb_field=number_input_fileds,
output_numb_field=1,
filter_dim=filter_dim,
stride=stride)
That's it! In just one line of code we have created the continuous
convolutional filter. By default the ``pina.model.FeedForward`` neural
network is initialised; more details are in the
`documentation <https://mathlab.github.io/PINA/_rst/fnn.html>`__. In
case the mesh doesn't change during training, we can set the ``optimize``
flag to ``True``, to exploit optimizations for finding the points
to convolve.
.. code:: ipython3
# creating the filter + optimization
cConv = ContinuousConvBlock(input_numb_field=number_input_fileds,
output_numb_field=1,
filter_dim=filter_dim,
stride=stride,
optimize=True)
Let's try to do a forward pass
.. code:: ipython3
@@ -229,7 +238,7 @@ Lets try to do a forward pass
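# reconstructed sketch of this cell (the body is elided by the diff);
# `data` is a [1, 1, N, 3] tensor like the one assembled earlier
output = cConv(data)
print(f"Filter output data has shape: {output.shape}")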
Filter output data has shape: torch.Size([1, 1, 169, 3])
If we don't want to use the default ``FeedForward`` neural network, we
can pass a specific torch model via the ``model`` keyword, as follows:
.. code:: ipython3
@@ -248,7 +257,7 @@ can pass a specified torch model in the ``model`` keyword as follow:
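# hypothetical custom model (reconstructed sketch; the class body is
# elided by the diff): it maps the 2 filter coordinates to 1 weight
class MyTorchModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Linear(2, 20),
            torch.nn.ReLU(),
            torch.nn.Linear(20, 1),
        )

    def forward(self, x):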
return self.model(x)
cConv = ContinuousConvBlock(input_numb_field=number_input_fileds,
output_numb_field=1,
filter_dim=filter_dim,
stride=stride,
@@ -261,7 +270,7 @@ Notice that we pass the class and not an already built object!
Building a MNIST Classifier
---------------------------
Let's see how we can build a MNIST classifier using a continuous
convolutional filter. We will use the MNIST dataset from PyTorch. In
order to keep training times small, we use only 6000 samples for training
and 1000 samples for testing.
@@ -299,7 +308,68 @@ and 1000 samples for testing.
# note: this loader draws from a subsample of the training data,
# matching the "performance on the train set" check later on
test_loader = DataLoader(train_data, batch_size=batch_size,
                         sampler=SubsetRandomSampler(subsample_train_indices))
Let's now build a simple classifier. The MNIST dataset is composed of
tensors of shape ``[batch, 1, 28, 28]``, but we can imagine them as
one-field functions where the pixels :math:`ij` are the coordinates
:math:`x=i, y=j` in a :math:`[0, 27]\times[0,27]` domain, and the pixels
@@ -353,7 +423,7 @@ filter followed by a feedforward neural network
numb_class = 10
# convolutional block
self.convolution = ContinuousConvBlock(input_numb_field=1,
output_numb_field=4,
stride={"domain": [27, 27],
"start": [0, 0],
@@ -363,8 +433,8 @@ filter followed by a feedforward neural network
filter_dim=[4, 4],
optimize=True)
# feedforward net
self.nn = FeedForward(input_dimensions=196,
output_dimensions=numb_class,
layers=[120, 64],
func=torch.nn.ReLU)
@@ -378,7 +448,7 @@ filter followed by a feedforward neural network
net = ContinuousClassifier()
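Before training, each image batch has to be cast into the point-cloud layout described above. The tutorial's conversion code is elided here; a minimal sketch (the helper name is our own) could be:

.. code:: python

    def image_to_pointcloud(imgs):
        """Hypothetical helper: cast [B, 1, 28, 28] images to [B, 1, 784, 3]."""
        B = imgs.shape[0]
        x, y = torch.meshgrid(torch.arange(28), torch.arange(28), indexing="ij")
        coords = torch.stack([x.flatten(), y.flatten()], dim=-1).float()  # [784, 2]
        coords = coords.unsqueeze(0).unsqueeze(0).expand(B, 1, -1, -1)
        values = imgs.reshape(B, 1, 784, 1)
        return torch.cat([coords, values], dim=-1)  # columns: (x, y, pixel value)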
Let's try to train it using a simple pytorch training loop. We train for
just 1 epoch using the Adam optimizer with a :math:`0.001` learning rate.
.. code:: ipython3
@@ -410,31 +480,37 @@ juts 1 epoch using Adam optimizer with a :math:`0.001` learning rate.
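# reconstructed sketch of the loop head (elided by the diff); the names
# `train_loader` and `image_to_pointcloud` are our assumptions from above
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
running_loss = 0.0
for i, (imgs, labels) in enumerate(train_loader):
    optimizer.zero_grad()
    outputs = net(image_to_pointcloud(imgs))
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()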
running_loss += loss.item()
if i % 50 == 49:
print(
f'batch [{i + 1}/{numb_training//batch_size}] loss[{running_loss / 500:.3f}]')
running_loss = 0.0
.. parsed-literal::
batch [50/750] loss[0.161]
batch [100/750] loss[0.073]
batch [150/750] loss[0.063]
batch [200/750] loss[0.051]
batch [250/750] loss[0.044]
batch [300/750] loss[0.050]
batch [350/750] loss[0.053]
batch [400/750] loss[0.049]
batch [450/750] loss[0.046]
batch [500/750] loss[0.034]
batch [550/750] loss[0.036]
batch [600/750] loss[0.040]
batch [650/750] loss[0.028]
batch [700/750] loss[0.040]
batch [750/750] loss[0.040]
Let's see the performance on the train set!
.. code:: ipython3
@@ -457,11 +533,11 @@ Lets see the performance on the train set!
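# reconstructed sketch of the evaluation loop (the cell body is elided);
# reuses the hypothetical `image_to_pointcloud` helper from above
correct, total = 0, 0
with torch.no_grad():
    for imgs, labels in test_loader:
        outputs = net(image_to_pointcloud(imgs))
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Accuracy of the network on the {total} test images: {100 * correct / total:.3f}%')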
.. parsed-literal::
Accuracy of the network on the 1000 test images: 92.733%
As we can see, we have very good performance considering we trained for
only 1 epoch! Nevertheless, we are still using structured data... Let's see how
we can build an autoencoder for unstructured data now.
Building a Continuous Convolutional Autoencoder
@@ -470,7 +546,7 @@ Building a Continuous Convolutional Autoencoder
Just as a toy problem, we will now build an autoencoder for the following
function :math:`f(x,y)=\sin(\pi x)\sin(\pi y)` on the unit circle domain
centered at :math:`(0.5, 0.5)`. We will also see the ability to
up-sample (once trained) the results without retraining. Let's first
create the input and visualize it; we will initially use a mesh of
:math:`100` points.
@@ -516,12 +592,12 @@ create the input and visualize it, we will use firstly a mesh of
.. image:: output_32_0.png
Let's now build a simple autoencoder using the continuous convolutional
filter. The data is clearly unstructured, and a simple convolutional
filter might not work without projecting or interpolating first. Let's
first build an ``Encoder`` and a ``Decoder`` class, and then an
``Autoencoder`` class that contains both.
@@ -532,7 +608,7 @@ first build and ``Encoder`` and ``Decoder`` class, and then a
super().__init__()
# convolutional block
self.convolution = ContinuousConvBlock(input_numb_field=1,
output_numb_field=2,
stride={"domain": [1, 1],
"start": [0, 0],
@@ -542,8 +618,8 @@ first build and ``Encoder`` and ``Decoder`` class, and then a
filter_dim=[0.15, 0.15],
optimize=True)
# feedforward net
self.nn = FeedForward(input_dimensions=400,
output_dimensions=hidden_dimension,
layers=[240, 120])
def forward(self, x):
@@ -558,7 +634,7 @@ first build and ``Encoder`` and ``Decoder`` class, and then a
super().__init__()
# convolutional block
self.convolution = ContinuousConvBlock(input_numb_field=2,
output_numb_field=1,
stride={"domain": [1, 1],
"start": [0, 0],
@@ -568,8 +644,8 @@ first build and ``Encoder`` and ``Decoder`` class, and then a
filter_dim=[0.15, 0.15],
optimize=True)
# feedforward net
self.nn = FeedForward(input_dimensions=hidden_dimension,
output_dimensions=400,
layers=[120, 240])
def forward(self, weights, grid):
@@ -582,7 +658,7 @@ first build and ``Encoder`` and ``Decoder`` class, and then a
Very good! Notice that in the ``Decoder`` class, in the ``forward`` pass,
we have used the ``.transpose()`` method of the
``ContinuousConvBlock`` class. This method accepts the ``weights`` for
upsampling and the ``grid`` on which to upsample. Let's now build the
autoencoder! We set the hidden dimension in the ``hidden_dimension``
variable. We apply the sigmoid on the output since the field value is
between :math:`[0, 1]`.
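The actual wrapper cell is elided below; a minimal sketch of how it might look, assuming ``Encoder`` and ``Decoder`` take the hidden dimension in their constructors, is:

.. code:: python

    class Autoencoder(torch.nn.Module):
        def __init__(self, hidden_dimension=10):  # the hidden size is our assumption
            super().__init__()
            self.encoder = Encoder(hidden_dimension)
            self.decoder = Decoder(hidden_dimension)

        def forward(self, x):
            grid = x.clone().detach()          # keep the input grid for upsampling
            weights = self.encoder(x)
            out = self.decoder(weights, grid)  # Decoder.forward(weights, grid)
            return torch.sigmoid(out)          # field values lie in [0, 1]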
@@ -608,7 +684,7 @@ between :math:`[0, 1]`.
net = Autoencoder()
Let's now train the autoencoder, minimizing the mean square error loss
and optimizing using Adam.
.. code:: ipython3
@@ -640,24 +716,24 @@ and optimizing using Adam.
.. parsed-literal::
epoch [10/150] loss [0.012]
epoch [20/150] loss [0.0036]
epoch [30/150] loss [0.0018]
epoch [40/150] loss [0.0014]
epoch [50/150] loss [0.0012]
epoch [60/150] loss [0.001]
epoch [70/150] loss [0.0009]
epoch [80/150] loss [0.00082]
epoch [90/150] loss [0.00075]
epoch [100/150] loss [0.0007]
epoch [110/150] loss [0.00066]
epoch [120/150] loss [0.00063]
epoch [130/150] loss [0.00061]
epoch [140/150] loss [0.00059]
epoch [150/150] loss [0.00058]
Let's visualize the two solutions side by side!
.. code:: ipython3
@@ -681,7 +757,7 @@ Lets visualize the two solutions side by side!
.. image:: output_40_0.png
As we can see, the two are really similar! We can compute the :math:`l_2`
@@ -698,19 +774,19 @@ error quite easily as well:
.. parsed-literal::
l2 error: 4.22%
More or less :math:`4\%` in :math:`l_2` error, which is really low
considering the fact that we use just **one** convolutional layer and a
simple feedforward to decrease the dimension. Let's now look at some
peculiarities of the filter.
Filter for upsampling
~~~~~~~~~~~~~~~~~~~~~
Suppose we already have the hidden representation and we want to upsample
it on a different grid with more points. Let's see how to do it:
.. code:: ipython3
@@ -744,11 +820,11 @@ a differen grid with more points. Lets see how to do it:
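# reconstructed sketch (the cell body is elided): encode the coarse data
# once, then decode on a hypothetical finer grid `fine_grid` of shape
# [1, 1, N_fine, 3], built like the training input but with more points;
# the `encoder`/`decoder` attribute names follow the sketch above
with torch.no_grad():
    weights = net.encoder(data)
    upsampled = torch.sigmoid(net.decoder(weights, fine_grid))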
.. image:: output_45_0.png
As we can see, we have a very good approximation of the original
function, even though some noise is present. Let's calculate the error
now:
.. code:: ipython3
@@ -758,7 +834,7 @@ now:
.. parsed-literal::
l2 error: 8.37%
Autoencoding at different resolution
@@ -768,7 +844,7 @@ In the previous example we already had the hidden dimension (of original
input) and we used it to upsample. Sometimes, however, we have a finer
mesh solution and we simply want to encode it. This can be done without
retraining! This procedure can be useful when we have many points in
the mesh and only a smaller part of them is needed for training. Let's
see the results of this:
.. code:: ipython3
@@ -807,15 +883,15 @@ see the results of this:
.. image:: output_49_0.png
.. parsed-literal::
l2 error: 8.50%
What's next?
------------
We have shown the basic usage of a convolutional filter. In the next