Continuous Convolution (#69)
* network handling update * adding tutorial * docs
This commit is contained in:
@@ -8,3 +8,4 @@ Code Documentation
|
||||
FeedForward <fnn.rst>
|
||||
DeepONet <deeponet.rst>
|
||||
PINN <pinn.rst>
|
||||
ContinuousConv <convolution.rst>
|
||||
|
||||
12
docs/source/_rst/convolution.rst
Normal file
12
docs/source/_rst/convolution.rst
Normal file
@@ -0,0 +1,12 @@
|
||||
ContinuousConv
|
||||
==============
|
||||
.. currentmodule:: pina.model.layers.convolution_2d
|
||||
|
||||
.. automodule:: pina.model.layers.convolution_2d
|
||||
|
||||
.. autoclass:: ContinuousConv
|
||||
:members:
|
||||
:private-members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
:noindex:
|
||||
824
docs/source/_rst/tutorial4/tutorial.rst
Normal file
824
docs/source/_rst/tutorial4/tutorial.rst
Normal file
@@ -0,0 +1,824 @@
|
||||
Tutorial 4: continuous convolutional filter
|
||||
===========================================
|
||||
|
||||
In this tutorial we will show how to use the Continouous Convolutional
|
||||
Filter, and how to build common Deep Learning architectures with it. The
|
||||
implementation of the filter follows the original work `A Continuous
|
||||
Convolutional Trainable Filter for Modelling Unstructured
|
||||
Data <https://arxiv.org/abs/2210.13416>`__ of Coscia Dario, Laura
|
||||
Meneghetti, Nicola Demo, Giovanni Stabile, and Gianluigi Rozza.
|
||||
|
||||
First of all we import the modules needed for the tutorial, which
|
||||
include:
|
||||
|
||||
- ``ContinuousConv`` class from ``pina.model.layers`` which implements
|
||||
the continuous convolutional filter
|
||||
- ``PyTorch`` and ``Matplotlib`` for tensorial operations and
|
||||
visualization respectively
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
from pina.model.layers import ContinuousConv
|
||||
import torchvision # for MNIST dataset
|
||||
from pina.model import FeedForward # for building AE and MNIST classification
|
||||
|
||||
The tutorial is structured as follow:
|
||||
|
||||
* `Continuous filter background <#continuous-filter-background>`__: understand how the convolutional filter works and how to use it.
|
||||
|
||||
* `Building a MNIST Classifier <#building-a-mnist-classifier>`__: show how to build a simple classifier using the MNIST dataset and how to combine a continuous convolutional layer with a feedforward neural network.
|
||||
|
||||
* `Building a Continuous Convolutional Autoencoder <#building-a-continuous-convolutional-autoencoder>`__: show how to use the continuous filter to work with unstructured data for autoencoding and up-sampling.
|
||||
|
||||
Continuous filter background
|
||||
----------------------------
|
||||
|
||||
As reported by the authors in the original paper: in contrast to
|
||||
discrete convolution, continuous convolution is mathematically defined
|
||||
as:
|
||||
|
||||
.. math::
|
||||
|
||||
|
||||
\mathcal{I}_{\rm{out}}(\mathbf{x}) = \int_{\mathcal{X}} \mathcal{I}(\mathbf{x} + \mathbf{\tau}) \cdot \mathcal{K}(\mathbf{\tau}) d\mathbf{\tau},
|
||||
|
||||
where :math:`\mathcal{K} : \mathcal{X} \rightarrow \mathbb{R}` is the
|
||||
*continuous filter* function, and
|
||||
:math:`\mathcal{I} : \Omega \subset \mathbb{R}^N \rightarrow \mathbb{R}`
|
||||
is the input function. The continuous filter function is approximated
|
||||
using a FeedForward Neural Network, thus trainable during the training
|
||||
phase. The way in which the integral is approximated can be different,
|
||||
currently on **PINA** we approximate it using a simple sum, as suggested
|
||||
by the authors. Thus, given :math:`\{\mathbf{x}_i\}_{i=1}^{n}` points in
|
||||
:math:`\mathbb{R}^N` of the input function mapped on the
|
||||
:math:`\mathcal{X}` filter domain, we approximate the above equation as:
|
||||
|
||||
.. math::
|
||||
|
||||
|
||||
\mathcal{I}_{\rm{out}}(\mathbf{\tilde{x}}_i) = \sum_{{\mathbf{x}_i}\in\mathcal{X}} \mathcal{I}(\mathbf{x}_i + \mathbf{\tau}) \cdot \mathcal{K}(\mathbf{x}_i),
|
||||
|
||||
where :math:`\mathbf{\tau} \in \mathcal{S}`, with :math:`\mathcal{S}`
|
||||
the set of available strides, corresponds to the current stride position
|
||||
of the filter, and :math:`\mathbf{\tilde{x}}_i` points are obtained by
|
||||
taking the centroid of the filter position mapped on the :math:`\Omega`
|
||||
domain.
|
||||
|
||||
We will now try to pratically see how to work with the filter. From the
|
||||
above definition we see that what is needed is: 1. A domain and a
|
||||
function defined on that domain (the input) 2. A stride, corresponding
|
||||
to the positions where the filter needs to be :math:`\rightarrow`
|
||||
``stride`` variable in ``ContinuousConv`` 3. The filter rectangular
|
||||
domain :math:`\rightarrow` ``filter_dim`` variable in ``ContinuousConv``
|
||||
|
||||
Input function
|
||||
~~~~~~~~~~~~~~
|
||||
|
||||
The input function for the continuous filter is defined as a tensor of
|
||||
shape:
|
||||
|
||||
.. math:: [B \times N_{in} \times N \times D]
|
||||
|
||||
\ where :math:`B` is the batch_size, :math:`N_{in}` is the number of
|
||||
input fields, :math:`N` the number of points in the mesh, :math:`D` the
|
||||
dimension of the problem. In particular:
|
||||
|
||||
* :math:`D` is the number of spatial variables + 1. The last column must contain the field value. For example for 2D problems :math:`D=3` and the tensor will be something like ``[first coordinate, second coordinate, field value]``
|
||||
|
||||
* :math:`N_{in}` represents the number of vectorial function presented. For example a vectorial function :math:`f = [f_1, f_2]` will have math:`N_{in}=2`
|
||||
|
||||
Let’s see an example to clear the ideas. We will be verbose to explain
|
||||
in details the input form. We wish to create the function:
|
||||
|
||||
.. math::
|
||||
|
||||
|
||||
f(x, y) = [\sin(\pi x) \sin(\pi y), -\sin(\pi x) \sin(\pi y)] \quad (x,y)\in[0,1]\times[0,1]
|
||||
|
||||
using a batch size of one.
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
# batch size fixed to 1
|
||||
batch_size = 1
|
||||
|
||||
# points in the mesh fixed to 200
|
||||
N = 200
|
||||
|
||||
# vectorial 2 dimensional function, number_input_fileds=2
|
||||
number_input_fileds = 2
|
||||
|
||||
# 2 dimensional spatial variables, D = 2 + 1 = 3
|
||||
D = 3
|
||||
|
||||
# create the function f domain as random 2d points in [0, 1]
|
||||
domain = torch.rand(size=(batch_size, number_input_fileds, N, D-1))
|
||||
print(f"Domain has shape: {domain.shape}")
|
||||
|
||||
# create the functions
|
||||
pi = torch.acos(torch.tensor([-1.])) # pi value
|
||||
f1 = torch.sin(pi * domain[:, 0, :, 0]) * torch.sin(pi * domain[:, 0, :, 1])
|
||||
f2 = - torch.sin(pi * domain[:, 1, :, 0]) * torch.sin(pi * domain[:, 1, :, 1])
|
||||
|
||||
# stacking the input domain and field values
|
||||
data = torch.empty(size=(batch_size, number_input_fileds, N, D))
|
||||
data[..., :-1] = domain # copy the domain
|
||||
data[:, 0, :, -1] = f1 # copy first field value
|
||||
data[:, 1, :, -1] = f1 # copy second field value
|
||||
print(f"Filter input data has shape: {data.shape}")
|
||||
|
||||
|
||||
.. parsed-literal::
|
||||
|
||||
Domain has shape: torch.Size([1, 2, 200, 2])
|
||||
Filter input data has shape: torch.Size([1, 2, 200, 3])
|
||||
|
||||
|
||||
Stride
|
||||
~~~~~~
|
||||
|
||||
The stride is passed as a dictionary ``stride`` which tells the filter
|
||||
where to go. Here is an example for the :math:`[0,1]\times[0,5]` domain:
|
||||
|
||||
.. code:: python
|
||||
|
||||
# stride definition
|
||||
stride = {"domain": [1, 5],
|
||||
"start": [0, 0],
|
||||
"jump": [0.1, 0.3],
|
||||
"direction": [1, 1],
|
||||
}
|
||||
|
||||
This tells the filter:
|
||||
|
||||
1. ``domain``: square domain (the only implemented) :math:`[0,1]\times[0,5]`. The minimum value is always zero, while the maximum is specified by the user
|
||||
|
||||
2. ``start``: start position of the filter, coordinate :math:`(0, 0)`
|
||||
|
||||
3. ``jump``: the jumps of the centroid of the filter to the next position :math:`(0.1, 0.3)`
|
||||
|
||||
4. ``direction``: the directions of the jump, with ``1 = right``, ``0 = no jump``,\ ``-1 = left`` with respect to the current position
|
||||
|
||||
**Note**
|
||||
|
||||
We are planning to release the possibility to directly pass a list of
|
||||
possible strides!
|
||||
|
||||
Filter definition
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
Having defined all the previous blocks we are able to construct the
|
||||
continuous filter.
|
||||
|
||||
Suppose we would like to get an ouput with only one field, and let us
|
||||
fix the filter dimension to be :math:`[0.1, 0.1]`.
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
# filter dim
|
||||
filter_dim = [0.1, 0.1]
|
||||
|
||||
# stride
|
||||
stride = {"domain": [1, 1],
|
||||
"start": [0, 0],
|
||||
"jump": [0.08, 0.08],
|
||||
"direction": [1, 1],
|
||||
}
|
||||
|
||||
# creating the filter
|
||||
cConv = ContinuousConv(input_numb_field=number_input_fileds,
|
||||
output_numb_field=1,
|
||||
filter_dim=filter_dim,
|
||||
stride=stride)
|
||||
|
||||
That’s it! In just one line of code we have created the continuous
|
||||
convolutional filter. By default the ``pina.model.FeedForward`` neural
|
||||
network is intitialised, more on the
|
||||
`documentation <https://mathlab.github.io/PINA/_rst/fnn.html>`__. In
|
||||
case the mesh doesn’t change during training we can set the ``optimize``
|
||||
flag equals to ``True``, to exploit optimizations for finding the points
|
||||
to convolve.
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
# creating the filter + optimization
|
||||
cConv = ContinuousConv(input_numb_field=number_input_fileds,
|
||||
output_numb_field=1,
|
||||
filter_dim=filter_dim,
|
||||
stride=stride,
|
||||
optimize=True)
|
||||
|
||||
|
||||
Let’s try to do a forward pass
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
print(f"Filter input data has shape: {data.shape}")
|
||||
|
||||
#input to the filter
|
||||
output = cConv(data)
|
||||
|
||||
print(f"Filter output data has shape: {output.shape}")
|
||||
|
||||
|
||||
.. parsed-literal::
|
||||
|
||||
Filter input data has shape: torch.Size([1, 2, 200, 3])
|
||||
Filter output data has shape: torch.Size([1, 1, 169, 3])
|
||||
|
||||
|
||||
If we don’t want to use the default ``FeedForward`` neural network, we
|
||||
can pass a specified torch model in the ``model`` keyword as follow:
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
class SimpleKernel(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self. model = torch.nn.Sequential(
|
||||
torch.nn.Linear(2, 20),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(20, 20),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(20, 1))
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
|
||||
|
||||
cConv = ContinuousConv(input_numb_field=number_input_fileds,
|
||||
output_numb_field=1,
|
||||
filter_dim=filter_dim,
|
||||
stride=stride,
|
||||
optimize=True,
|
||||
model=SimpleKernel)
|
||||
|
||||
|
||||
Notice that we pass the class and not an already built object!
|
||||
|
||||
Building a MNIST Classifier
|
||||
---------------------------
|
||||
|
||||
Let’s see how we can build a MNIST classifier using a continuous
|
||||
convolutional filter. We will use the MNIST dataset from PyTorch. In
|
||||
order to keep small training times we use only 6000 samples for training
|
||||
and 1000 samples for testing.
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
from torch.utils.data import DataLoader, SubsetRandomSampler
|
||||
|
||||
numb_training = 6000 # get just 6000 images for training
|
||||
numb_testing= 1000 # get just 1000 images for training
|
||||
seed = 111 # for reproducibility
|
||||
batch_size = 8 # setting batch size
|
||||
|
||||
# setting the seed
|
||||
torch.manual_seed(seed)
|
||||
|
||||
# downloading the dataset
|
||||
train_data = torchvision.datasets.MNIST('./data/', train=True, download=True,
|
||||
transform=torchvision.transforms.Compose([
|
||||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize(
|
||||
(0.1307,), (0.3081,))
|
||||
]))
|
||||
subsample_train_indices = torch.randperm(len(train_data))[:numb_training]
|
||||
train_loader = DataLoader(train_data, batch_size=batch_size,
|
||||
sampler=SubsetRandomSampler(subsample_train_indices))
|
||||
|
||||
test_data = torchvision.datasets.MNIST('./data/', train=False, download=True,
|
||||
transform=torchvision.transforms.Compose([
|
||||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize(
|
||||
(0.1307,), (0.3081,))
|
||||
]))
|
||||
subsample_test_indices = torch.randperm(len(train_data))[:numb_testing]
|
||||
test_loader = DataLoader(train_data, batch_size=batch_size,
|
||||
sampler=SubsetRandomSampler(subsample_train_indices))
|
||||
|
||||
Let’s now build a simple classifier. The MNIST dataset is composed by
|
||||
vectors of shape ``[batch, 1, 28, 28]``, but we can image them as one
|
||||
field functions where the pixels :math:`ij` are the coordinate
|
||||
:math:`x=i, y=j` in a :math:`[0, 27]\times[0,27]` domain, and the pixels
|
||||
value are the field values. We just need a function to transform the
|
||||
regular tensor in a tensor compatible for the continuous filter:
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
def transform_input(x):
|
||||
batch_size = x.shape[0]
|
||||
dim_grid = tuple(x.shape[:-3:-1])
|
||||
|
||||
# creating the n dimensional mesh grid for a single channel image
|
||||
values_mesh = [torch.arange(0, dim).float() for dim in dim_grid]
|
||||
mesh = torch.meshgrid(values_mesh)
|
||||
coordinates_mesh = [x.reshape(-1, 1) for x in mesh]
|
||||
coordinates = torch.cat(coordinates_mesh, dim=1).unsqueeze(
|
||||
0).repeat((batch_size, 1, 1)).unsqueeze(1)
|
||||
|
||||
return torch.cat((coordinates, x.flatten(2).unsqueeze(-1)), dim=-1)
|
||||
|
||||
|
||||
# let's try it out
|
||||
image, s = next(iter(train_loader))
|
||||
print(f"Original MNIST image shape: {image.shape}")
|
||||
|
||||
image_transformed = transform_input(image)
|
||||
print(f"Transformed MNIST image shape: {image_transformed.shape}")
|
||||
|
||||
|
||||
|
||||
.. parsed-literal::
|
||||
|
||||
Original MNIST image shape: torch.Size([8, 1, 28, 28])
|
||||
Transformed MNIST image shape: torch.Size([8, 1, 784, 3])
|
||||
|
||||
|
||||
We can now build a simple classifier! We will use just one convolutional
|
||||
filter followed by a feedforward neural network
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
# setting the seed
|
||||
torch.manual_seed(seed)
|
||||
|
||||
class ContinuousClassifier(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# number of classes for classification
|
||||
numb_class = 10
|
||||
|
||||
# convolutional block
|
||||
self.convolution = ContinuousConv(input_numb_field=1,
|
||||
output_numb_field=4,
|
||||
stride={"domain": [27, 27],
|
||||
"start": [0, 0],
|
||||
"jumps": [4, 4],
|
||||
"direction": [1, 1.],
|
||||
},
|
||||
filter_dim=[4, 4],
|
||||
optimize=True)
|
||||
# feedforward net
|
||||
self.nn = FeedForward(input_variables=196,
|
||||
output_variables=numb_class,
|
||||
layers=[120, 64],
|
||||
func=torch.nn.ReLU)
|
||||
|
||||
def forward(self, x):
|
||||
# transform input + convolution
|
||||
x = transform_input(x)
|
||||
x = self.convolution(x)
|
||||
# feed forward classification
|
||||
return self.nn(x[..., -1].flatten(1))
|
||||
|
||||
|
||||
net = ContinuousClassifier()
|
||||
|
||||
Let’s try to train it using a simple pytorch training loop. We train for
|
||||
juts 1 epoch using Adam optimizer with a :math:`0.001` learning rate.
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
# setting the seed
|
||||
torch.manual_seed(seed)
|
||||
|
||||
# optimizer and loss function
|
||||
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
for epoch in range(1): # loop over the dataset multiple times
|
||||
|
||||
running_loss = 0.0
|
||||
for i, data in enumerate(train_loader, 0):
|
||||
# get the inputs; data is a list of [inputs, labels]
|
||||
inputs, labels = data
|
||||
|
||||
# zero the parameter gradients
|
||||
optimizer.zero_grad()
|
||||
|
||||
# forward + backward + optimize
|
||||
outputs = net(inputs)
|
||||
loss = criterion(outputs, labels)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# print statistics
|
||||
running_loss += loss.item()
|
||||
if i % 50 == 49:
|
||||
print(
|
||||
f'epoch [{i + 1}/{numb_training//batch_size}] loss[{running_loss / 500:.3f}]')
|
||||
running_loss = 0.0
|
||||
|
||||
|
||||
|
||||
.. parsed-literal::
|
||||
|
||||
epoch [50/750] loss[0.148]
|
||||
epoch [100/750] loss[0.072]
|
||||
epoch [150/750] loss[0.063]
|
||||
epoch [200/750] loss[0.053]
|
||||
epoch [250/750] loss[0.041]
|
||||
epoch [300/750] loss[0.048]
|
||||
epoch [350/750] loss[0.054]
|
||||
epoch [400/750] loss[0.048]
|
||||
epoch [450/750] loss[0.047]
|
||||
epoch [500/750] loss[0.035]
|
||||
epoch [550/750] loss[0.036]
|
||||
epoch [600/750] loss[0.041]
|
||||
epoch [650/750] loss[0.030]
|
||||
epoch [700/750] loss[0.040]
|
||||
epoch [750/750] loss[0.040]
|
||||
|
||||
|
||||
Let’s see the performance on the train set!
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
correct = 0
|
||||
total = 0
|
||||
with torch.no_grad():
|
||||
for data in test_loader:
|
||||
images, labels = data
|
||||
# calculate outputs by running images through the network
|
||||
outputs = net(images)
|
||||
# the class with the highest energy is what we choose as prediction
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
total += labels.size(0)
|
||||
correct += (predicted == labels).sum().item()
|
||||
|
||||
print(
|
||||
f'Accuracy of the network on the 1000 test images: {(correct / total):.3%}')
|
||||
|
||||
|
||||
|
||||
.. parsed-literal::
|
||||
|
||||
Accuracy of the network on the 1000 test images: 93.017%
|
||||
|
||||
|
||||
As we can see we have very good performance for having traing only for 1
|
||||
epoch! Nevertheless, we are still using structured data… Let’s see how
|
||||
we can build an autoencoder for unstructured data now.
|
||||
|
||||
Building a Continuous Convolutional Autoencoder
|
||||
-----------------------------------------------
|
||||
|
||||
Just as toy problem, we will now build an autoencoder for the following
|
||||
function :math:`f(x,y)=\sin(\pi x)\sin(\pi y)` on the unit circle domain
|
||||
centered in :math:`(0.5, 0.5)`. We will also see the ability to
|
||||
up-sample (once trained) the results without retraining. Let’s first
|
||||
create the input and visualize it, we will use firstly a mesh of
|
||||
:math:`100` points.
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
# create inputs
|
||||
def circle_grid(N=100):
|
||||
"""Generate points withing a unit 2D circle centered in (0.5, 0.5)
|
||||
|
||||
:param N: number of points
|
||||
:type N: float
|
||||
:return: [x, y] array of points
|
||||
:rtype: torch.tensor
|
||||
"""
|
||||
|
||||
PI = torch.acos(torch.zeros(1)).item() * 2
|
||||
R = 0.5
|
||||
centerX = 0.5
|
||||
centerY = 0.5
|
||||
|
||||
r = R * torch.sqrt(torch.rand(N))
|
||||
theta = torch.rand(N) * 2 * PI
|
||||
|
||||
x = centerX + r * torch.cos(theta)
|
||||
y = centerY + r * torch.sin(theta)
|
||||
|
||||
return torch.stack([x, y]).T
|
||||
|
||||
# create the grid
|
||||
grid = circle_grid(500)
|
||||
|
||||
# create input
|
||||
input_data = torch.empty(size=(1, 1, grid.shape[0], 3))
|
||||
input_data[0, 0, :, :-1] = grid
|
||||
input_data[0, 0, :, -1] = torch.sin(pi * grid[:, 0]) * torch.sin(pi * grid[:, 1])
|
||||
|
||||
# visualize data
|
||||
plt.title("Training sample with 500 points")
|
||||
plt.scatter(grid[:, 0], grid[:, 1], c=input_data[0, 0, :, -1])
|
||||
plt.colorbar()
|
||||
plt.show()
|
||||
|
||||
|
||||
|
||||
|
||||
.. image:: tutorial_files/tutorial_32_0.png
|
||||
|
||||
|
||||
Let’s now build a simple autoencoder using the continuous convolutional
|
||||
filter. The data is clearly unstructured and a simple convolutional
|
||||
filter might not work without projecting or interpolating first. Let’s
|
||||
first build and ``Encoder`` and ``Decoder`` class, and then a
|
||||
``Autoencoder`` class that contains both.
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
class Encoder(torch.nn.Module):
|
||||
def __init__(self, hidden_dimension):
|
||||
super().__init__()
|
||||
|
||||
# convolutional block
|
||||
self.convolution = ContinuousConv(input_numb_field=1,
|
||||
output_numb_field=2,
|
||||
stride={"domain": [1, 1],
|
||||
"start": [0, 0],
|
||||
"jumps": [0.05, 0.05],
|
||||
"direction": [1, 1.],
|
||||
},
|
||||
filter_dim=[0.15, 0.15],
|
||||
optimize=True)
|
||||
# feedforward net
|
||||
self.nn = FeedForward(input_variables=400,
|
||||
output_variables=hidden_dimension,
|
||||
layers=[240, 120])
|
||||
|
||||
def forward(self, x):
|
||||
# convolution
|
||||
x = self.convolution(x)
|
||||
# feed forward pass
|
||||
return self.nn(x[..., -1])
|
||||
|
||||
|
||||
class Decoder(torch.nn.Module):
|
||||
def __init__(self, hidden_dimension):
|
||||
super().__init__()
|
||||
|
||||
# convolutional block
|
||||
self.convolution = ContinuousConv(input_numb_field=2,
|
||||
output_numb_field=1,
|
||||
stride={"domain": [1, 1],
|
||||
"start": [0, 0],
|
||||
"jumps": [0.05, 0.05],
|
||||
"direction": [1, 1.],
|
||||
},
|
||||
filter_dim=[0.15, 0.15],
|
||||
optimize=True)
|
||||
# feedforward net
|
||||
self.nn = FeedForward(input_variables=hidden_dimension,
|
||||
output_variables=400,
|
||||
layers=[120, 240])
|
||||
|
||||
def forward(self, weights, grid):
|
||||
# feed forward pass
|
||||
x = self.nn(weights)
|
||||
# transpose convolution
|
||||
return torch.sigmoid(self.convolution.transpose(x, grid))
|
||||
|
||||
|
||||
Very good! Notice that in the ``Decoder`` class in the ``forward`` pass
|
||||
we have used the ``.transpose()`` method of the
|
||||
``ContinuousConvolution`` class. This method accepts the ``weights`` for
|
||||
upsampling and the ``grid`` on where to upsample. Let’s now build the
|
||||
autoencoder! We set the hidden dimension in the ``hidden_dimension``
|
||||
variable. We apply the sigmoid on the output since the field value is
|
||||
between :math:`[0, 1]`.
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
class Autoencoder(torch.nn.Module):
|
||||
def __init__(self, hidden_dimension=10):
|
||||
super().__init__()
|
||||
|
||||
self.encoder = Encoder(hidden_dimension)
|
||||
self.decoder = Decoder(hidden_dimension)
|
||||
|
||||
def forward(self, x):
|
||||
# saving grid for later upsampling
|
||||
grid = x.clone().detach()
|
||||
# encoder
|
||||
weights = self.encoder(x)
|
||||
# decoder
|
||||
out = self.decoder(weights, grid)
|
||||
return out
|
||||
|
||||
|
||||
net = Autoencoder()
|
||||
|
||||
Let’s now train the autoencoder, minimizing the mean square error loss
|
||||
and optimizing using Adam.
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
# setting the seed
|
||||
torch.manual_seed(seed)
|
||||
|
||||
# optimizer and loss function
|
||||
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
|
||||
criterion = torch.nn.MSELoss()
|
||||
max_epochs = 150
|
||||
|
||||
for epoch in range(max_epochs): # loop over the dataset multiple times
|
||||
|
||||
# zero the parameter gradients
|
||||
optimizer.zero_grad()
|
||||
|
||||
# forward + backward + optimize
|
||||
outputs = net(input_data)
|
||||
loss = criterion(outputs[..., -1], input_data[..., -1])
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# print statistics
|
||||
if epoch % 10 ==9:
|
||||
print(f'epoch [{epoch + 1}/{max_epochs}] loss [{loss.item():.2}]')
|
||||
|
||||
|
||||
|
||||
.. parsed-literal::
|
||||
|
||||
epoch [10/150] loss [0.013]
|
||||
epoch [20/150] loss [0.0029]
|
||||
epoch [30/150] loss [0.0019]
|
||||
epoch [40/150] loss [0.0014]
|
||||
epoch [50/150] loss [0.0011]
|
||||
epoch [60/150] loss [0.00094]
|
||||
epoch [70/150] loss [0.00082]
|
||||
epoch [80/150] loss [0.00074]
|
||||
epoch [90/150] loss [0.00068]
|
||||
epoch [100/150] loss [0.00064]
|
||||
epoch [110/150] loss [0.00061]
|
||||
epoch [120/150] loss [0.00058]
|
||||
epoch [130/150] loss [0.00057]
|
||||
epoch [140/150] loss [0.00056]
|
||||
epoch [150/150] loss [0.00054]
|
||||
|
||||
|
||||
Let’s visualize the two solutions side by side!
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
net.eval()
|
||||
|
||||
# get output and detach from computational graph for plotting
|
||||
output = net(input_data).detach()
|
||||
|
||||
# visualize data
|
||||
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(8, 3))
|
||||
pic1 = axes[0].scatter(grid[:, 0], grid[:, 1], c=input_data[0, 0, :, -1])
|
||||
axes[0].set_title("Real")
|
||||
fig.colorbar(pic1)
|
||||
plt.subplot(1, 2, 2)
|
||||
pic2 = axes[1].scatter(grid[:, 0], grid[:, 1], c=output[0, 0, :, -1])
|
||||
axes[1].set_title("Autoencoder")
|
||||
fig.colorbar(pic2)
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
|
||||
|
||||
.. image:: tutorial_files/tutorial_40_0.png
|
||||
|
||||
|
||||
As we can see the two are really similar! We can compute the :math:`l_2`
|
||||
error quite easily as well:
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
def l2_error(input_, target):
|
||||
return torch.linalg.norm(input_-target, ord=2)/torch.linalg.norm(input_, ord=2)
|
||||
|
||||
|
||||
print(f'l2 error: {l2_error(input_data[0, 0, :, -1], output[0, 0, :, -1]):.2%}')
|
||||
|
||||
|
||||
.. parsed-literal::
|
||||
|
||||
l2 error: 4.10%
|
||||
|
||||
|
||||
More or less :math:`4\%` in :math:`l_2` error, which is really low
|
||||
considering the fact that we use just **one** convolutional layer and a
|
||||
simple feedforward to decrease the dimension. Let’s see now some
|
||||
peculiarity of the filter.
|
||||
|
||||
Filter for upsampling
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Suppose we have already the hidden dimension and we want to upsample on
|
||||
a differen grid with more points. Let’s see how to do it:
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
# setting the seed
|
||||
torch.manual_seed(seed)
|
||||
|
||||
grid2 = circle_grid(1500) # triple number of points
|
||||
input_data2 = torch.zeros(size=(1, 1, grid2.shape[0], 3))
|
||||
input_data2[0, 0, :, :-1] = grid2
|
||||
input_data2[0, 0, :, -1] = torch.sin(pi *
|
||||
grid2[:, 0]) * torch.sin(pi * grid2[:, 1])
|
||||
|
||||
# get the hidden dimension representation from original input
|
||||
latent = net.encoder(input_data)
|
||||
|
||||
# upsample on the second input_data2
|
||||
output = net.decoder(latent, input_data2).detach()
|
||||
|
||||
# show the picture
|
||||
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(8, 3))
|
||||
pic1 = axes[0].scatter(grid2[:, 0], grid2[:, 1], c=input_data2[0, 0, :, -1])
|
||||
axes[0].set_title("Real")
|
||||
fig.colorbar(pic1)
|
||||
plt.subplot(1, 2, 2)
|
||||
pic2 = axes[1].scatter(grid2[:, 0], grid2[:, 1], c=output[0, 0, :, -1])
|
||||
axes[1].set_title("Up-sampling")
|
||||
fig.colorbar(pic2)
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
|
||||
|
||||
.. image:: tutorial_files/tutorial_45_0.png
|
||||
|
||||
|
||||
As we can see we have a very good approximation of the original
|
||||
function, even thought some noise is present. Let’s calculate the error
|
||||
now:
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
print(f'l2 error: {l2_error(input_data2[0, 0, :, -1], output[0, 0, :, -1]):.2%}')
|
||||
|
||||
|
||||
.. parsed-literal::
|
||||
|
||||
l2 error: 8.44%
|
||||
|
||||
|
||||
Autoencoding at different resolution
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
In the previous example we already had the hidden dimension (of original
|
||||
input) and we used it to upsample. Sometimes however we have a more fine
|
||||
mesh solution and we simply want to encode it. This can be done without
|
||||
retraining! This procedure can be useful in case we have many points in
|
||||
the mesh and just a smaller part of them are needed for training. Let’s
|
||||
see the results of this:
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
# setting the seed
|
||||
torch.manual_seed(seed)
|
||||
|
||||
grid2 = circle_grid(3500) # very fine mesh
|
||||
input_data2 = torch.zeros(size=(1, 1, grid2.shape[0], 3))
|
||||
input_data2[0, 0, :, :-1] = grid2
|
||||
input_data2[0, 0, :, -1] = torch.sin(pi *
|
||||
grid2[:, 0]) * torch.sin(pi * grid2[:, 1])
|
||||
|
||||
# get the hidden dimension representation from more fine mesh input
|
||||
latent = net.encoder(input_data2)
|
||||
|
||||
# upsample on the second input_data2
|
||||
output = net.decoder(latent, input_data2).detach()
|
||||
|
||||
# show the picture
|
||||
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(8, 3))
|
||||
pic1 = axes[0].scatter(grid2[:, 0], grid2[:, 1], c=input_data2[0, 0, :, -1])
|
||||
axes[0].set_title("Real")
|
||||
fig.colorbar(pic1)
|
||||
plt.subplot(1, 2, 2)
|
||||
pic2 = axes[1].scatter(grid2[:, 0], grid2[:, 1], c=output[0, 0, :, -1])
|
||||
axes[1].set_title("Autoencoder not re-trained")
|
||||
fig.colorbar(pic2)
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
# calculate l2 error
|
||||
print(
|
||||
f'l2 error: {l2_error(input_data2[0, 0, :, -1], output[0, 0, :, -1]):.2%}')
|
||||
|
||||
|
||||
|
||||
|
||||
.. image:: tutorial_files/tutorial_49_0.png
|
||||
|
||||
|
||||
.. parsed-literal::
|
||||
|
||||
l2 error: 8.45%
|
||||
|
||||
|
||||
What’s next?
|
||||
------------
|
||||
|
||||
We have shown the basic usage of a convolutional filter. In the next
|
||||
tutorials we will show how to combine the PINA framework with the
|
||||
convolutional filter to train in few lines and efficiently a Neural
|
||||
Network!
|
||||
BIN
docs/source/_rst/tutorial4/tutorial_files/tutorial_32_0.png
Normal file
BIN
docs/source/_rst/tutorial4/tutorial_files/tutorial_32_0.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 61 KiB |
BIN
docs/source/_rst/tutorial4/tutorial_files/tutorial_40_0.png
Normal file
BIN
docs/source/_rst/tutorial4/tutorial_files/tutorial_40_0.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 84 KiB |
BIN
docs/source/_rst/tutorial4/tutorial_files/tutorial_45_0.png
Normal file
BIN
docs/source/_rst/tutorial4/tutorial_files/tutorial_45_0.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 83 KiB |
BIN
docs/source/_rst/tutorial4/tutorial_files/tutorial_49_0.png
Normal file
BIN
docs/source/_rst/tutorial4/tutorial_files/tutorial_49_0.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 72 KiB |
@@ -38,6 +38,7 @@ solve problems in a continuous and nonlinear settings.
|
||||
Getting start with PINA <_rst/tutorial1/tutorial.rst>
|
||||
Poisson problem <_rst/tutorial2/tutorial.rst>
|
||||
Wave equation <_rst/tutorial3/tutorial.rst>
|
||||
Continuous Convolutional Filter <_rst/tutorial4/tutorial.rst>
|
||||
|
||||
.. ........................................................................................
|
||||
|
||||
|
||||
7
pina/model/layers/__init__.py
Normal file
7
pina/model/layers/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
__all__ = [
|
||||
'BaseContinuousConv',
|
||||
'ContinuousConv'
|
||||
]
|
||||
|
||||
from .convolution import BaseContinuousConv
|
||||
from .convolution_2d import ContinuousConv
|
||||
154
pina/model/layers/convolution.py
Normal file
154
pina/model/layers/convolution.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""Module for Base Continuous Convolution class."""
|
||||
from abc import ABCMeta, abstractmethod
|
||||
import torch
|
||||
from .stride import Stride
|
||||
from .utils_convolution import optimizing
|
||||
|
||||
|
||||
class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
|
||||
"""
|
||||
Abstract class
|
||||
"""
|
||||
|
||||
def __init__(self, input_numb_field, output_numb_field,
|
||||
filter_dim, stride, model=None, optimize=False,
|
||||
no_overlap=False):
|
||||
"""Base Class for Continuous Convolution.
|
||||
|
||||
The algorithm expects input to be in the form:
|
||||
$$[B \times N_{in} \times N \times D]$$
|
||||
where $B$ is the batch_size, $N_{in}$ is the number of input
|
||||
fields, $N$ the number of points in the mesh, $D$ the dimension
|
||||
of the problem. In particular:
|
||||
* $D$ is the number of spatial variables + 1. The last column must
|
||||
contain the field value. For example for 2D problems $D=3$ and
|
||||
the tensor will be something like `[first coordinate, second
|
||||
coordinate, field value]`.
|
||||
* $N_{in}$ represents the number of vectorial function presented.
|
||||
For example a vectorial function $f = [f_1, f_2]$ will have
|
||||
$N_{in}=2$.
|
||||
|
||||
:Note
|
||||
A 2-dimensional vectorial function $N_{in}=2$ of 3-dimensional
|
||||
input $D=3+1=4$ with 100 points input mesh and batch size of 8
|
||||
is represented as a tensor `[8, 2, 100, 4]`, where the columns
|
||||
`[:, 0, :, -1]` and `[:, 1, :, -1]` represent the first and
|
||||
second filed value respectively
|
||||
|
||||
The algorithm returns a tensor of shape:
|
||||
$$[B \times N_{out} \times N' \times D]$$
|
||||
where $B$ is the batch_size, $N_{out}$ is the number of output
|
||||
fields, $N'$ the number of points in the mesh, $D$ the dimension
|
||||
of the problem.
|
||||
|
||||
:param input_numb_field: number of fields in the input
|
||||
:type input_numb_field: int
|
||||
:param output_numb_field: number of fields in the output
|
||||
:type output_numb_field: int
|
||||
:param filter_dim: dimension of the filter
|
||||
:type filter_dim: tuple/ list
|
||||
:param stride: stride for the filter
|
||||
:type stride: dict
|
||||
:param model: neural network for inner parametrization,
|
||||
defaults to None
|
||||
:type model: torch.nn.Module, optional
|
||||
:param optimize: flag for performing optimization on the continuous
|
||||
filter, defaults to False. The flag `optimize=True` should be
|
||||
used only when the scatter datapoints are fixed through the
|
||||
training. If torch model is in `.eval()` mode, the flag is
|
||||
automatically set to False always.
|
||||
:type optimize: bool, optional
|
||||
:param no_overlap: flag for performing optimization on the transpose
|
||||
continuous filter, defaults to False. The flag set to `True` should
|
||||
be used only when the filter positions do not overlap for different
|
||||
strides. RuntimeError will raise in case of non-compatible strides.
|
||||
:type no_overlap: bool, optional
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if isinstance(input_numb_field, int):
|
||||
self._input_numb_field = input_numb_field
|
||||
else:
|
||||
raise ValueError('input_numb_field must be int.')
|
||||
|
||||
if isinstance(output_numb_field, int):
|
||||
self._output_numb_field = output_numb_field
|
||||
else:
|
||||
raise ValueError('input_numb_field must be int.')
|
||||
|
||||
if isinstance(filter_dim, (tuple, list)):
|
||||
vect = filter_dim
|
||||
else:
|
||||
raise ValueError('filter_dim must be tuple or list.')
|
||||
vect = torch.tensor(vect)
|
||||
self.register_buffer("_dim", vect, persistent=False)
|
||||
|
||||
if isinstance(stride, dict):
|
||||
self._stride = Stride(stride)
|
||||
else:
|
||||
raise ValueError('stride must be dictionary.')
|
||||
|
||||
self._net = model
|
||||
|
||||
if isinstance(optimize, bool):
|
||||
self._optimize = optimize
|
||||
else:
|
||||
raise ValueError('optimize must be bool.')
|
||||
|
||||
# choosing how to initialize based on optimization
|
||||
if self._optimize:
|
||||
# optimizing decorator ensure the function is called
|
||||
# just once
|
||||
self._choose_initialization = optimizing(
|
||||
self._initialize_convolution)
|
||||
else:
|
||||
self._choose_initialization = self._initialize_convolution
|
||||
|
||||
if not isinstance(no_overlap, bool):
|
||||
raise ValueError('no_overlap must be bool.')
|
||||
|
||||
if no_overlap:
|
||||
raise NotImplementedError
|
||||
self.transpose = self.transpose_no_overlap
|
||||
else:
|
||||
self.transpose = self.transpose_overlap
|
||||
|
||||
@ property
|
||||
def net(self):
|
||||
return self._net
|
||||
|
||||
@ property
|
||||
def stride(self):
|
||||
return self._stride
|
||||
|
||||
@ property
|
||||
def dim(self):
|
||||
return self._dim
|
||||
|
||||
@ property
|
||||
def input_numb_field(self):
|
||||
return self._input_numb_field
|
||||
|
||||
@ property
|
||||
def output_numb_field(self):
|
||||
return self._output_numb_field
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def forward(self, X):
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def transpose_overlap(self, X):
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def transpose_no_overlap(self, X):
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def _initialize_convolution(self, X, type):
|
||||
pass
|
||||
548
pina/model/layers/convolution_2d.py
Normal file
548
pina/model/layers/convolution_2d.py
Normal file
@@ -0,0 +1,548 @@
|
||||
"""Module for Continuous Convolution class"""
|
||||
from .convolution import BaseContinuousConv
|
||||
from .utils_convolution import check_point, map_points_
|
||||
from .integral import Integral
|
||||
from ..feed_forward import FeedForward
|
||||
import torch
|
||||
|
||||
|
||||
class ContinuousConv(BaseContinuousConv):
|
||||
"""
|
||||
Implementation of Continuous Convolutional operator.
|
||||
|
||||
.. seealso::
|
||||
|
||||
**Original reference**: Coscia, D., Meneghetti, L., Demo, N.,
|
||||
Stabile, G., & Rozza, G.. (2022). A Continuous Convolutional Trainable
|
||||
Filter for Modelling Unstructured Data.
|
||||
DOI: `10.48550/arXiv.2210.13416
|
||||
<https://doi.org/10.48550/arXiv.2210.13416>`_.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, input_numb_field, output_numb_field,
|
||||
filter_dim, stride, model=None, optimize=False,
|
||||
no_overlap=False):
|
||||
"""
|
||||
|
||||
:param input_numb_field: Number of fields N_in in the input.
|
||||
:type input_numb_field: int
|
||||
:param output_numb_field: Number of fields N_out in the output.
|
||||
:type output_numb_field: int
|
||||
:param filter_dim: Dimension of the filter.
|
||||
:type filter_dim: tuple/ list
|
||||
:param stride: Stride for the filter.
|
||||
:type stride: dict
|
||||
:param model: Neural network for inner parametrization,
|
||||
defaults to None. If None, pina.FeedForward is used, more
|
||||
on https://mathlab.github.io/PINA/_rst/fnn.html.
|
||||
:type model: torch.nn.Module, optional
|
||||
:param optimize: Flag for performing optimization on the continuous
|
||||
filter, defaults to False. The flag `optimize=True` should be
|
||||
used only when the scatter datapoints are fixed through the
|
||||
training. If torch model is in `.eval()` mode, the flag is
|
||||
automatically set to False always.
|
||||
:type optimize: bool, optional
|
||||
:param no_overlap: Flag for performing optimization on the transpose
|
||||
continuous filter, defaults to False. The flag set to `True` should
|
||||
be used only when the filter positions do not overlap for different
|
||||
strides. RuntimeError will raise in case of non-compatible strides.
|
||||
:type no_overlap: bool, optional
|
||||
|
||||
.. note::
|
||||
Using `optimize=True` the filter can be use either in `forward`
|
||||
or in `transpose` mode, not both. If `optimize=False` the same
|
||||
filter can be used for both `transpose` and `forward` modes.
|
||||
|
||||
.. warning::
|
||||
The algorithm expects input to be in the form: [B x N_in x N x D]
|
||||
where B is the batch_size, N_in is the number of input
|
||||
fields, N the number of points in the mesh, D the dimension
|
||||
of the problem. In particular:
|
||||
|
||||
* D is the number of spatial variables + 1. The last column must
|
||||
contain the field value. For example for 2D problems D=3 and
|
||||
the tensor will be something like `[first coordinate, second
|
||||
coordinate, field value]`.
|
||||
|
||||
* N_in represents the number of vectorial function presented.
|
||||
For example a vectorial function f = [f_1, f_2] will have
|
||||
N_in=2.
|
||||
|
||||
The algorithm returns a tensor of shape: [B x N_out x N x D]
|
||||
where B is the batch_size, N_out is the number of output
|
||||
fields, N' the number of points in the mesh, D the dimension
|
||||
of the problem (coordinates + field value).
|
||||
|
||||
For example, a 2-dimensional vectorial function N_in=2 of
|
||||
3-dimensionalcinput D=3+1=4 with 100 points input mesh and batch
|
||||
size of 8 is represented as a tensor `[8, 2, 100, 4]`, where the
|
||||
columnsc`[:, 0, :, -1]` and `[:, 1, :, -1]` represent the first and
|
||||
second filed value respectively.
|
||||
|
||||
:Example:
|
||||
>>> class MLP(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self. model = torch.nn.Sequential(
|
||||
torch.nn.Linear(2, 8),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(8, 8),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(8, 1))
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
>>> dim = [3, 3]
|
||||
>>> stride = {"domain": [10, 10],
|
||||
"start": [0, 0],
|
||||
"jumps": [3, 3],
|
||||
"direction": [1, 1.]}
|
||||
>>> conv = ContinuousConv2D(1, 2, dim, stride, MLP)
|
||||
>>> conv
|
||||
ContinuousConv2D(
|
||||
(_net): ModuleList(
|
||||
(0): MLP(
|
||||
(model): Sequential(
|
||||
(0): Linear(in_features=2, out_features=8, bias=True)
|
||||
(1): ReLU()
|
||||
(2): Linear(in_features=8, out_features=8, bias=True)
|
||||
(3): ReLU()
|
||||
(4): Linear(in_features=8, out_features=1, bias=True)
|
||||
)
|
||||
)
|
||||
(1): MLP(
|
||||
(model): Sequential(
|
||||
(0): Linear(in_features=2, out_features=8, bias=True)
|
||||
(1): ReLU()
|
||||
(2): Linear(in_features=8, out_features=8, bias=True)
|
||||
(3): ReLU()
|
||||
(4): Linear(in_features=8, out_features=1, bias=True)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
"""
|
||||
|
||||
super().__init__(input_numb_field=input_numb_field,
|
||||
output_numb_field=output_numb_field,
|
||||
filter_dim=filter_dim,
|
||||
stride=stride,
|
||||
model=model,
|
||||
optimize=optimize,
|
||||
no_overlap=no_overlap)
|
||||
|
||||
# integral routine
|
||||
self._integral = Integral('discrete')
|
||||
|
||||
# create the network
|
||||
self._net = self._spawn_networks(model)
|
||||
|
||||
# stride for continuous convolution overridden
|
||||
self._stride = self._stride._stride_discrete
|
||||
|
||||
def _spawn_networks(self, model):
|
||||
"""Private method to create a collection of kernels
|
||||
|
||||
:param model: a torch.nn.Module model in form of Object class
|
||||
:type model: torch.nn.Module
|
||||
:return: list of torch.nn.Module models
|
||||
:rtype: torch.nn.ModuleList
|
||||
|
||||
"""
|
||||
nets = []
|
||||
if self._net is None:
|
||||
for _ in range(self._input_numb_field * self._output_numb_field):
|
||||
tmp = FeedForward(len(self._dim), 1)
|
||||
nets.append(tmp)
|
||||
else:
|
||||
if not isinstance(model, object):
|
||||
raise ValueError("Expected a python class inheriting"
|
||||
" from torch.nn.Module")
|
||||
|
||||
for _ in range(self._input_numb_field * self._output_numb_field):
|
||||
tmp = model()
|
||||
if not isinstance(tmp, torch.nn.Module):
|
||||
raise ValueError("The python class must be inherited from"
|
||||
" torch.nn.Module. See the docstring for"
|
||||
" an example.")
|
||||
nets.append(tmp)
|
||||
|
||||
return torch.nn.ModuleList(nets)
|
||||
|
||||
def _extract_mapped_points(self, batch_idx, index, x):
|
||||
"""Priviate method to extract mapped points in the filter
|
||||
|
||||
:param x: input tensor [channel x N x dim]
|
||||
:type x: torch.tensor
|
||||
:return: mapped points and indeces for each channel
|
||||
:rtype: tuple(torch.tensor, list)
|
||||
|
||||
"""
|
||||
mapped_points = []
|
||||
indeces_channels = []
|
||||
|
||||
for stride_idx, current_stride in enumerate(self._stride):
|
||||
|
||||
# indeces of points falling into filter range
|
||||
indeces = index[stride_idx][batch_idx]
|
||||
|
||||
# how many points for each channel fall into the filter?
|
||||
numb_points_insiede = torch.sum(indeces, dim=-1).tolist()
|
||||
|
||||
# extracting points for each channel
|
||||
# shape: [sum(numb_points_insiede), filter_dim + 1]
|
||||
point_stride = x[indeces]
|
||||
|
||||
# mapping points in filter domain
|
||||
map_points_(point_stride[..., :-1], current_stride)
|
||||
|
||||
# extracting points for each channel
|
||||
point_stride_channel = point_stride.split(numb_points_insiede)
|
||||
|
||||
# appending in list for later use
|
||||
mapped_points.append(point_stride_channel)
|
||||
indeces_channels.append(numb_points_insiede)
|
||||
|
||||
# stacking input for passing to neural net
|
||||
mapping = map(torch.cat, zip(*mapped_points))
|
||||
stacked_input = tuple(mapping)
|
||||
indeces_channels = tuple(zip(*indeces_channels))
|
||||
|
||||
return stacked_input, indeces_channels
|
||||
|
||||
def _find_index(self, X):
|
||||
"""Private method to extract indeces for convolution.
|
||||
|
||||
:param X: input tensor, as in ContinuousConv2D docstring
|
||||
:type X: torch.tensor
|
||||
|
||||
"""
|
||||
# append the index for each stride
|
||||
index = []
|
||||
for _, current_stride in enumerate(self._stride):
|
||||
|
||||
tmp = check_point(X, current_stride, self._dim)
|
||||
index.append(tmp)
|
||||
|
||||
# storing the index
|
||||
self._index = index
|
||||
|
||||
def _make_grid_forward(self, X):
|
||||
"""Private method to create forward convolution grid.
|
||||
|
||||
:param X: input tensor, as in ContinuousConv2D docstring
|
||||
:type X: torch.tensor
|
||||
|
||||
"""
|
||||
# filter dimension + number of points in output grid
|
||||
filter_dim = len(self._dim)
|
||||
number_points = len(self._stride)
|
||||
|
||||
# initialize the grid
|
||||
grid = torch.zeros(size=(X.shape[0],
|
||||
self._output_numb_field,
|
||||
number_points,
|
||||
filter_dim + 1),
|
||||
device=X.device,
|
||||
dtype=X.dtype)
|
||||
grid[..., :-1] = (self._stride + self._dim * 0.5)
|
||||
|
||||
# saving the grid
|
||||
self._grid = grid.detach()
|
||||
|
||||
def _make_grid_transpose(self, X):
|
||||
"""Private method to create transpose convolution grid.
|
||||
|
||||
:param X: input tensor, as in ContinuousConv2D docstring
|
||||
:type X: torch.tensor
|
||||
|
||||
"""
|
||||
# initialize to all zeros
|
||||
tmp = torch.zeros_like(X)
|
||||
tmp[..., :-1] = X[..., :-1]
|
||||
|
||||
# save on tmp
|
||||
self._grid_transpose = tmp
|
||||
|
||||
def _make_grid(self, X, type):
|
||||
"""Private method to create convolution grid.
|
||||
|
||||
:param X: input tensor, as in ContinuousConv2D docstring
|
||||
:type X: torch.tensor
|
||||
:param type: type of convolution, ['forward', 'inverse'] the
|
||||
possibilities
|
||||
:type type: string
|
||||
|
||||
"""
|
||||
# choose the type of convolution
|
||||
if type == 'forward':
|
||||
return self._make_grid_forward(X)
|
||||
elif type == 'inverse':
|
||||
self._make_grid_transpose(X)
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def _initialize_convolution(self, X, type='forward'):
|
||||
"""Private method to intialize the convolution.
|
||||
The convolution is initialized by setting a grid and
|
||||
calculate the index for finding the points inside the
|
||||
filter.
|
||||
|
||||
:param X: input tensor, as in ContinuousConv2D docstring
|
||||
:type X: torch.tensor
|
||||
:param type: type of convolution, ['forward', 'inverse'] the
|
||||
possibilities
|
||||
:type type: string
|
||||
"""
|
||||
|
||||
# variable for the convolution
|
||||
self._make_grid(X, type)
|
||||
|
||||
# calculate the index
|
||||
self._find_index(X)
|
||||
|
||||
def forward(self, X):
|
||||
"""Forward pass in the layer
|
||||
|
||||
:param x: input data (input_numb_field x N x filter_dim)
|
||||
:type x: torch.tensor
|
||||
:return: feed forward convolution (output_numb_field x N x filter_dim)
|
||||
:rtype: torch.tensor
|
||||
"""
|
||||
|
||||
# initialize convolution
|
||||
if self.training: # we choose what to do based on optimization
|
||||
self._choose_initialization(X, type='forward')
|
||||
|
||||
else: # we always initialize on testing
|
||||
self._initialize_convolution(X, 'forward')
|
||||
|
||||
# create convolutional array
|
||||
conv = self._grid.clone().detach()
|
||||
|
||||
# total number of fields
|
||||
tot_dim = self._output_numb_field * self._input_numb_field
|
||||
|
||||
for batch_idx, x in enumerate(X):
|
||||
|
||||
# extract mapped points
|
||||
stacked_input, indeces_channels = self._extract_mapped_points(
|
||||
batch_idx, self._index, x)
|
||||
|
||||
# compute the convolution
|
||||
|
||||
# storing intermidiate results for each channel convolution
|
||||
res_tmp = []
|
||||
# for each field
|
||||
for idx_conv in range(tot_dim):
|
||||
# index for each input field
|
||||
idx = idx_conv % self._input_numb_field
|
||||
# extract input for each channel
|
||||
single_channel_input = stacked_input[idx]
|
||||
# extract filter
|
||||
net = self._net[idx_conv]
|
||||
# calculate filter value
|
||||
staked_output = net(single_channel_input[..., :-1])
|
||||
# perform integral for all strides in one field
|
||||
integral = self._integral(staked_output,
|
||||
single_channel_input[..., -1],
|
||||
indeces_channels[idx])
|
||||
res_tmp.append(integral)
|
||||
|
||||
# stacking integral results
|
||||
res_tmp = torch.stack(res_tmp)
|
||||
|
||||
# sum filters (for each input fields) in groups
|
||||
# for different ouput fields
|
||||
conv[batch_idx, ..., -1] = res_tmp.reshape(self._output_numb_field,
|
||||
self._input_numb_field,
|
||||
-1).sum(1)
|
||||
return conv
|
||||
|
||||
def transpose_no_overlap(self, integrals, X):
|
||||
"""Transpose pass in the layer for no-overlapping filters
|
||||
|
||||
:param integrals: Weights for the transpose convolution. Shape
|
||||
[B x N_in x N]
|
||||
where B is the batch_size, N_in is the number of input
|
||||
fields, N the number of points in the mesh, D the dimension
|
||||
of the problem.
|
||||
:type integral: torch.tensor
|
||||
:param X: Input data. Expect tensor of shape
|
||||
[B x N_in x M x D] where B is the batch_size,
|
||||
N_in is the number of input fields, M the number of points
|
||||
in the mesh, D the dimension of the problem. Note, last column
|
||||
:type X: torch.tensor
|
||||
:return: Feed forward transpose convolution. Tensor of shape
|
||||
[B x N_out x N] where B is the batch_size,
|
||||
N_out is the number of output fields, N the number of points
|
||||
in the mesh, D the dimension of the problem.
|
||||
:rtype: torch.tensor
|
||||
|
||||
.. note::
|
||||
This function is automatically called when `.transpose()`
|
||||
method is used and `no_overlap=True`
|
||||
"""
|
||||
|
||||
# initialize convolution
|
||||
if self.training: # we choose what to do based on optimization
|
||||
self._choose_initialization(X, type='inverse')
|
||||
|
||||
else: # we always initialize on testing
|
||||
self._initialize_convolution(X, 'inverse')
|
||||
|
||||
# initialize grid
|
||||
X = self._grid_transpose.clone().detach()
|
||||
conv_transposed = self._grid_transpose.clone().detach()
|
||||
|
||||
# total number of dim
|
||||
tot_dim = self._input_numb_field * self._output_numb_field
|
||||
|
||||
for batch_idx, x in enumerate(X):
|
||||
|
||||
# extract mapped points
|
||||
stacked_input, indeces_channels = self._extract_mapped_points(
|
||||
batch_idx, self._index, x)
|
||||
|
||||
# compute the transpose convolution
|
||||
|
||||
# total number of fields
|
||||
res_tmp = []
|
||||
|
||||
# for each field
|
||||
for idx_conv in range(tot_dim):
|
||||
# index for each output field
|
||||
idx = idx_conv % self._output_numb_field
|
||||
# index for each input field
|
||||
idx_in = idx_conv % self._input_numb_field
|
||||
# extract input for each field
|
||||
single_channel_input = stacked_input[idx]
|
||||
rep_idx = torch.tensor(indeces_channels[idx])
|
||||
integral = integrals[batch_idx,
|
||||
idx_in, :].repeat_interleave(rep_idx)
|
||||
# extract filter
|
||||
net = self._net[idx_conv]
|
||||
# perform transpose convolution for all strides in one field
|
||||
staked_output = net(single_channel_input[..., :-1]).flatten()
|
||||
integral = staked_output * integral
|
||||
res_tmp.append(integral)
|
||||
|
||||
# stacking integral results and sum
|
||||
# filters (for each input fields) in groups
|
||||
# for different output fields
|
||||
res_tmp = torch.stack(res_tmp).reshape(self._input_numb_field,
|
||||
self._output_numb_field,
|
||||
-1).sum(0)
|
||||
conv_transposed[batch_idx, ..., -1] = res_tmp
|
||||
|
||||
return conv_transposed
|
||||
|
||||
def transpose_overlap(self, integrals, X):
|
||||
"""Transpose pass in the layer for overlapping filters
|
||||
|
||||
:param integrals: Weights for the transpose convolution. Shape
|
||||
[B x N_in x N]
|
||||
where B is the batch_size, N_in is the number of input
|
||||
fields, N the number of points in the mesh, D the dimension
|
||||
of the problem.
|
||||
:type integral: torch.tensor
|
||||
:param X: Input data. Expect tensor of shape
|
||||
[B x N_in x M x D] where B is the batch_size,
|
||||
N_in is the number of input fields, M the number of points
|
||||
in the mesh, D the dimension of the problem. Note, last column
|
||||
:type X: torch.tensor
|
||||
:return: Feed forward transpose convolution. Tensor of shape
|
||||
[B x N_out x N] where B is the batch_size,
|
||||
N_out is the number of output fields, N the number of points
|
||||
in the mesh, D the dimension of the problem.
|
||||
:rtype: torch.tensor
|
||||
|
||||
.. note:: This function is automatically called when `.transpose()`
|
||||
method is used and `no_overlap=False`
|
||||
"""
|
||||
|
||||
# initialize convolution
|
||||
if self.training: # we choose what to do based on optimization
|
||||
self._choose_initialization(X, type='inverse')
|
||||
|
||||
else: # we always initialize on testing
|
||||
self._initialize_convolution(X, 'inverse')
|
||||
|
||||
# initialize grid
|
||||
X = self._grid_transpose.clone().detach()
|
||||
conv_transposed = self._grid_transpose.clone().detach()
|
||||
|
||||
# list to iterate for calculating nn output
|
||||
tmp = [i for i in range(self._output_numb_field)]
|
||||
iterate_conv = [item for item in tmp for _ in range(
|
||||
self._input_numb_field)]
|
||||
|
||||
for batch_idx, x in enumerate(X):
|
||||
|
||||
# accumulator for the convolution on different batches
|
||||
accumulator_batch = torch.zeros(
|
||||
size=(self._grid_transpose.shape[1],
|
||||
self._grid_transpose.shape[2]),
|
||||
requires_grad=True,
|
||||
device=X.device,
|
||||
dtype=X.dtype).clone()
|
||||
|
||||
for stride_idx, current_stride in enumerate(self._stride):
|
||||
# indeces of points falling into filter range
|
||||
indeces = self._index[stride_idx][batch_idx]
|
||||
|
||||
# number of points for each channel
|
||||
numb_pts_channel = tuple(indeces.sum(dim=-1))
|
||||
|
||||
# extracting points for each channel
|
||||
point_stride = x[indeces]
|
||||
|
||||
# if no points to upsample we just skip
|
||||
if point_stride.nelement() == 0:
|
||||
continue
|
||||
|
||||
# mapping points in filter domain
|
||||
map_points_(point_stride[..., :-1], current_stride)
|
||||
|
||||
# input points for kernels
|
||||
# we split for extracting number of points for each channel
|
||||
nn_input_pts = point_stride[..., :-1].split(numb_pts_channel)
|
||||
|
||||
# accumulate partial convolution results for each field
|
||||
res_tmp = []
|
||||
|
||||
# for each channel field compute transpose convolution
|
||||
for idx_conv, idx_channel_out in enumerate(iterate_conv):
|
||||
|
||||
# index for input channels
|
||||
idx_channel_in = idx_conv % self._input_numb_field
|
||||
|
||||
# extract filter
|
||||
net = self._net[idx_conv]
|
||||
|
||||
# calculate filter value
|
||||
staked_output = net(nn_input_pts[idx_channel_out])
|
||||
|
||||
# perform integral for all strides in one field
|
||||
integral = staked_output * integrals[batch_idx,
|
||||
idx_channel_in,
|
||||
stride_idx]
|
||||
# append results
|
||||
res_tmp.append(integral.flatten())
|
||||
|
||||
# computing channel sum
|
||||
channel_sum = []
|
||||
start = 0
|
||||
for _ in range(self._output_numb_field):
|
||||
tmp = res_tmp[start:start + self._input_numb_field]
|
||||
tmp = torch.vstack(tmp).sum(dim=0)
|
||||
channel_sum.append(tmp)
|
||||
start += self._input_numb_field
|
||||
|
||||
# accumulate the results
|
||||
accumulator_batch[indeces] += torch.hstack(channel_sum)
|
||||
|
||||
# save results of accumulation for each batch
|
||||
conv_transposed[batch_idx, ..., -1] = accumulator_batch
|
||||
|
||||
return conv_transposed
|
||||
63
pina/model/layers/integral.py
Normal file
63
pina/model/layers/integral.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import torch
|
||||
|
||||
|
||||
class Integral(object):
|
||||
|
||||
def __init__(self, param):
|
||||
"""Integral class for continous convolution
|
||||
|
||||
:param param: type of continuous convolution
|
||||
:type param: string
|
||||
"""
|
||||
|
||||
if param == 'discrete':
|
||||
self.make_integral = self.integral_param_disc
|
||||
elif param == 'continuous':
|
||||
self.make_integral = self.integral_param_cont
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def __call__(self, *args, **kwds):
|
||||
return self.make_integral(*args, **kwds)
|
||||
|
||||
def _prepend_zero(self, x):
|
||||
"""Create bins for performing integral
|
||||
|
||||
:param x: input tensor
|
||||
:type x: torch.tensor
|
||||
:return: bins for integrals
|
||||
:rtype: torch.tensor
|
||||
"""
|
||||
return torch.cat((torch.zeros(1, dtype=x.dtype, device=x.device), x))
|
||||
|
||||
def integral_param_disc(self, x, y, idx):
|
||||
"""Perform discretize integral
|
||||
with discrete parameters
|
||||
|
||||
:param x: input vector
|
||||
:type x: torch.tensor
|
||||
:param y: input vector
|
||||
:type y: torch.tensor
|
||||
:param idx: indeces for different strides
|
||||
:type idx: list
|
||||
:return: integral
|
||||
:rtype: torch.tensor
|
||||
"""
|
||||
cs_idxes = self._prepend_zero(torch.cumsum(torch.tensor(idx), 0))
|
||||
cs = self._prepend_zero(torch.cumsum(x.flatten() * y.flatten(), 0))
|
||||
return cs[cs_idxes[1:]] - cs[cs_idxes[:-1]]
|
||||
|
||||
def integral_param_cont(self, x, y, idx):
|
||||
"""Perform discretize integral for continuous convolution
|
||||
with continuous parameters
|
||||
|
||||
:param x: input vector
|
||||
:type x: torch.tensor
|
||||
:param y: input vector
|
||||
:type y: torch.tensor
|
||||
:param idx: indeces for different strides
|
||||
:type idx: list
|
||||
:return: integral
|
||||
:rtype: torch.tensor
|
||||
"""
|
||||
raise NotImplementedError
|
||||
82
pina/model/layers/stride.py
Normal file
82
pina/model/layers/stride.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import torch
|
||||
|
||||
|
||||
class Stride(object):
|
||||
|
||||
def __init__(self, dict):
|
||||
"""Stride class for continous convolution
|
||||
|
||||
:param param: type of continuous convolution
|
||||
:type param: string
|
||||
"""
|
||||
|
||||
self._dict_stride = dict
|
||||
self._stride_continuous = None
|
||||
self._stride_discrete = self._create_stride_discrete(dict)
|
||||
|
||||
def _create_stride_discrete(self, my_dict):
|
||||
"""Creating the list for applying the filter
|
||||
|
||||
:param my_dict: Dictionary with the following arguments:
|
||||
domain size, starting position of the filter, jump size
|
||||
for the filter and direction of the filter
|
||||
:type my_dict: dict
|
||||
:raises IndexError: Values in the dict must have all same length
|
||||
:raises ValueError: Domain values must be greater than 0
|
||||
:raises ValueError: Direction must be either equal to 1, -1 or 0
|
||||
:raises IndexError: Direction and jumps must have zero in the same
|
||||
index
|
||||
:return: list of positions for the filter
|
||||
:rtype: list
|
||||
:Example:
|
||||
|
||||
|
||||
>>> stride = {"domain": [4, 4],
|
||||
"start": [-4, 2],
|
||||
"jump": [2, 2],
|
||||
"direction": [1, 1],
|
||||
}
|
||||
>>> create_stride(stride)
|
||||
[[-4.0, 2.0], [-4.0, 4.0], [-2.0, 2.0], [-2.0, 4.0]]
|
||||
"""
|
||||
|
||||
# we must check boundaries of the input as well
|
||||
|
||||
domain, start, jumps, direction = my_dict.values()
|
||||
|
||||
# checking
|
||||
|
||||
if not all([len(s) == len(domain) for s in my_dict.values()]):
|
||||
raise IndexError("values in the dict must have all same length")
|
||||
|
||||
if not all(v >= 0 for v in domain):
|
||||
raise ValueError("domain values must be greater than 0")
|
||||
|
||||
if not all(v == 1 or v == -1 or v == 0 for v in direction):
|
||||
raise ValueError("direction must be either equal to 1, -1 or 0")
|
||||
|
||||
seq_jumps = [i for i, e in enumerate(jumps) if e == 0]
|
||||
seq_direction = [i for i, e in enumerate(direction) if e == 0]
|
||||
|
||||
if seq_direction != seq_jumps:
|
||||
raise IndexError(
|
||||
"direction and jumps must have zero in the same index")
|
||||
|
||||
if seq_jumps:
|
||||
for i in seq_jumps:
|
||||
jumps[i] = domain[i]
|
||||
direction[i] = 1
|
||||
|
||||
# creating the stride grid
|
||||
values_mesh = [torch.arange(0, i, step).float()
|
||||
for i, step in zip(domain, jumps)]
|
||||
|
||||
values_mesh = [single * dim for single,
|
||||
dim in zip(values_mesh, direction)]
|
||||
|
||||
mesh = torch.meshgrid(values_mesh)
|
||||
coordinates_mesh = [x.reshape(-1, 1) for x in mesh]
|
||||
|
||||
stride = torch.cat(coordinates_mesh, dim=1) + torch.tensor(start)
|
||||
|
||||
return stride
|
||||
48
pina/model/layers/utils_convolution.py
Normal file
48
pina/model/layers/utils_convolution.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import torch
|
||||
|
||||
|
||||
def check_point(x, current_stride, dim):
|
||||
max_stride = current_stride + dim
|
||||
indeces = torch.logical_and(x[..., :-1] < max_stride,
|
||||
x[..., :-1] >= current_stride).all(dim=-1)
|
||||
return indeces
|
||||
|
||||
|
||||
def map_points_(x, filter_position):
|
||||
"""Mapping function n dimensional case
|
||||
|
||||
:param x: input data of two dimension
|
||||
:type x: torch.tensor
|
||||
:param filter_position: position of the filter
|
||||
:type dim: list[numeric]
|
||||
:return: data mapped inplace
|
||||
:rtype: torch.tensor
|
||||
"""
|
||||
x.add_(-filter_position)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def optimizing(f):
|
||||
"""Decorator for calling a function just once
|
||||
|
||||
:param f: python function
|
||||
:type f: function
|
||||
"""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
|
||||
if kwargs['type'] == 'forward':
|
||||
if not wrapper.has_run_inverse:
|
||||
wrapper.has_run_inverse = True
|
||||
return f(*args, **kwargs)
|
||||
|
||||
if kwargs['type'] == 'inverse':
|
||||
if not wrapper.has_run:
|
||||
wrapper.has_run = True
|
||||
return f(*args, **kwargs)
|
||||
|
||||
wrapper.has_run_inverse = False
|
||||
wrapper.has_run = False
|
||||
|
||||
return wrapper
|
||||
140
tests/test_conv.py
Normal file
140
tests/test_conv.py
Normal file
@@ -0,0 +1,140 @@
|
||||
from pina.model.layers import ContinuousConv
|
||||
import torch
|
||||
|
||||
|
||||
def prod(iterable):
|
||||
p = 1
|
||||
for n in iterable:
|
||||
p *= n
|
||||
return p
|
||||
|
||||
|
||||
def make_grid(x):
|
||||
def _transform_image(image):
|
||||
|
||||
# extracting image info
|
||||
channels, dimension = image.size()[0], image.size()[1:]
|
||||
|
||||
# initializing transfomed image
|
||||
coordinates = torch.zeros(
|
||||
[channels, prod(dimension), len(dimension) + 1]).to(image.device)
|
||||
|
||||
# creating the n dimensional mesh grid
|
||||
values_mesh = [torch.arange(0, dim).float().to(
|
||||
image.device) for dim in dimension]
|
||||
mesh = torch.meshgrid(values_mesh)
|
||||
coordinates_mesh = [x.reshape(-1, 1) for x in mesh]
|
||||
coordinates_mesh.append(0)
|
||||
|
||||
for count, channel in enumerate(image):
|
||||
coordinates_mesh[-1] = channel.reshape(-1, 1)
|
||||
coordinates[count] = torch.cat(coordinates_mesh, dim=1)
|
||||
|
||||
return coordinates
|
||||
|
||||
output = [_transform_image(current_image) for current_image in x]
|
||||
return torch.stack(output).to(x.device)
|
||||
|
||||
|
||||
class MLP(torch.nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self. model = torch.nn.Sequential(torch.nn.Linear(2, 8),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(8, 8),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(8, 1))
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
|
||||
|
||||
# INPUTS
|
||||
channel_input = 2
|
||||
channel_output = 6
|
||||
batch = 2
|
||||
N = 10
|
||||
dim = [3, 3]
|
||||
stride = {"domain": [10, 10],
|
||||
"start": [0, 0],
|
||||
"jumps": [3, 3],
|
||||
"direction": [1, 1.]}
|
||||
dim_filter = len(dim)
|
||||
dim_input = (batch, channel_input, 10, dim_filter)
|
||||
dim_output = (batch, channel_output, 4, dim_filter)
|
||||
x = torch.rand(dim_input)
|
||||
x = make_grid(x)
|
||||
|
||||
|
||||
def test_constructor():
|
||||
model = MLP
|
||||
|
||||
conv = ContinuousConv(channel_input,
|
||||
channel_output,
|
||||
dim,
|
||||
stride,
|
||||
model=model)
|
||||
conv = ContinuousConv(channel_input,
|
||||
channel_output,
|
||||
dim,
|
||||
stride,
|
||||
model=None)
|
||||
|
||||
|
||||
def test_forward():
|
||||
model = MLP
|
||||
|
||||
# simple forward
|
||||
conv = ContinuousConv(channel_input,
|
||||
channel_output,
|
||||
dim,
|
||||
stride,
|
||||
model=model)
|
||||
conv(x)
|
||||
|
||||
# simple forward with optimization
|
||||
conv = ContinuousConv(channel_input,
|
||||
channel_output,
|
||||
dim,
|
||||
stride,
|
||||
model=model,
|
||||
optimize=True)
|
||||
conv(x)
|
||||
|
||||
|
||||
def test_transpose():
|
||||
model = MLP
|
||||
|
||||
# simple transpose
|
||||
conv = ContinuousConv(channel_input,
|
||||
channel_output,
|
||||
dim,
|
||||
stride,
|
||||
model=model)
|
||||
|
||||
conv2 = ContinuousConv(channel_output,
|
||||
channel_input,
|
||||
dim,
|
||||
stride,
|
||||
model=model)
|
||||
|
||||
integrals = conv(x)
|
||||
conv2.transpose(integrals[..., -1], x)
|
||||
|
||||
stride_no_overlap = {"domain": [10, 10],
|
||||
"start": [0, 0],
|
||||
"jumps": dim,
|
||||
"direction": [1, 1.]}
|
||||
|
||||
# simple transpose with optimization
|
||||
# conv = ContinuousConv(channel_input,
|
||||
# channel_output,
|
||||
# dim,
|
||||
# stride_no_overlap,
|
||||
# model=model,
|
||||
# optimize=True,
|
||||
# no_overlap=True)
|
||||
|
||||
# integrals = conv(x)
|
||||
# conv.transpose(integrals[..., -1], x)
|
||||
1
tutorials/README.md
vendored
1
tutorials/README.md
vendored
@@ -8,5 +8,6 @@ In this folder we collect useful tutorials in order to understand the principles
|
||||
| Tutorial1 [[.ipynb](tutorial1/tutorial.ipynb), [.py](tutorial1/tutorial.py), [.html](http://mathlab.github.io/PINA/_rst/tutorial1/tutorial.html)]| Introduction to PINA features | `SpatialProblem` |
|
||||
| Tutorial2 [[.ipynb](tutorial2/tutorial.ipynb), [.py](tutorial2/tutorial.py), [.html](http://mathlab.github.io/PINA/_rst/tutorial2/tutorial.html)]| Poisson problem on regular domain using extra features | `SpatialProblem` |
|
||||
| Tutorial3 [[.ipynb](tutorial3/tutorial.ipynb), [.py](tutorial3/tutorial.py), [.html](http://mathlab.github.io/PINA/_rst/tutorial3/tutorial.html)]| Wave problem on regular domain using custom pytorch networks. | `SpatialProblem`, `TimeDependentProblem` |
|
||||
| Tutorial4 [[.ipynb](tutorial4/tutorial.ipynb), [.py](tutorial4/tutorial.py), [.html](http://mathlab.github.io/PINA/_rst/tutorial4/tutorial.html)]| Continuous Convolutional Filter usage. | `None` |
|
||||
|
||||
|
||||
|
||||
1031
tutorials/tutorial4/tutorial.ipynb
vendored
Normal file
1031
tutorials/tutorial4/tutorial.ipynb
vendored
Normal file
File diff suppressed because one or more lines are too long
638
tutorials/tutorial4/tutorial.py
vendored
Normal file
638
tutorials/tutorial4/tutorial.py
vendored
Normal file
@@ -0,0 +1,638 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf-8
|
||||
|
||||
# # Tutorial 4: continuous convolutional filter
|
||||
|
||||
# In this tutorial we will show how to use the Continouous Convolutional Filter, and how to build common Deep Learning architectures with it. The implementation of the filter follows the original work [**A Continuous Convolutional Trainable Filter for Modelling Unstructured Data**](https://arxiv.org/abs/2210.13416) of Coscia Dario, Laura Meneghetti, Nicola Demo, Giovanni Stabile, and Gianluigi Rozza.
|
||||
|
||||
# First of all we import the modules needed for the tutorial, which include:
|
||||
#
|
||||
# * `ContinuousConv` class from `pina.model.layers` which implements the continuous convolutional filter
|
||||
# * `PyTorch` and `Matplotlib` for tensorial operations and visualization respectively
|
||||
|
||||
# In[1]:
|
||||
|
||||
|
||||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
from pina.model.layers import ContinuousConv
|
||||
import torchvision # for MNIST dataset
|
||||
from pina.model import FeedForward # for building AE and MNIST classification
|
||||
|
||||
|
||||
# The tutorial is structured as follow:
|
||||
# * [Continuous filter background](#continuous-filter-background): understand how the convolutional filter works and how to use it.
|
||||
# * [Building a MNIST Classifier](#building-a-mnist-classifier): show how to build a simple classifier using the MNIST dataset and how to combine a continuous convolutional layer with a feedforward neural network.
|
||||
# * [Building a Continuous Convolutional Autoencoder](#building-a-continuous-convolutional-autoencoder): show how to use the continuous filter to work with unstructured data for autoencoding and up-sampling.
|
||||
|
||||
# ## Continuous filter background
|
||||
|
||||
# As reported by the authors in the original paper: in contrast to discrete convolution, continuous convolution is mathematically defined as:
|
||||
#
|
||||
# $$
|
||||
# \mathcal{I}_{\rm{out}}(\mathbf{x}) = \int_{\mathcal{X}} \mathcal{I}(\mathbf{x} + \mathbf{\tau}) \cdot \mathcal{K}(\mathbf{\tau}) d\mathbf{\tau},
|
||||
# $$
|
||||
# where $\mathcal{K} : \mathcal{X} \rightarrow \mathbb{R}$ is the *continuous filter* function, and $\mathcal{I} : \Omega \subset \mathbb{R}^N \rightarrow \mathbb{R}$ is the input function. The continuous filter function is approximated using a FeedForward Neural Network, thus trainable during the training phase. The way in which the integral is approximated can be different, currently on **PINA** we approximate it using a simple sum, as suggested by the authors. Thus, given $\{\mathbf{x}_i\}_{i=1}^{n}$ points in $\mathbb{R}^N$ of the input function mapped on the $\mathcal{X}$ filter domain, we approximate the above equation as:
|
||||
# $$
|
||||
# \mathcal{I}_{\rm{out}}(\mathbf{\tilde{x}}_i) = \sum_{{\mathbf{x}_i}\in\mathcal{X}} \mathcal{I}(\mathbf{x}_i + \mathbf{\tau}) \cdot \mathcal{K}(\mathbf{x}_i),
|
||||
# $$
|
||||
# where $\mathbf{\tau} \in \mathcal{S}$, with $\mathcal{S}$ the set of available strides, corresponds to the current stride position of the filter, and $\mathbf{\tilde{x}}_i$ points are obtained by taking the centroid of the filter position mapped on the $\Omega$ domain.
|
||||
|
||||
# We will now try to pratically see how to work with the filter. From the above definition we see that what is needed is:
|
||||
# 1. A domain and a function defined on that domain (the input)
|
||||
# 2. A stride, corresponding to the positions where the filter needs to be $\rightarrow$ `stride` variable in `ContinuousConv`
|
||||
# 3. The filter rectangular domain $\rightarrow$ `filter_dim` variable in `ContinuousConv`
|
||||
|
||||
# ### Input function
|
||||
#
|
||||
# The input function for the continuous filter is defined as a tensor of shape: $$[B \times N_{in} \times N \times D]$$ where $B$ is the batch_size, $N_{in}$ is the number of input fields, $N$ the number of points in the mesh, $D$ the dimension of the problem. In particular:
|
||||
# * $D$ is the number of spatial variables + 1. The last column must contain the field value. For example for 2D problems $D=3$ and the tensor will be something like `[first coordinate, second coordinate, field value]`
|
||||
# * $N_{in}$ represents the number of vectorial function presented. For example a vectorial function $f = [f_1, f_2]$ will have $N_{in}=2$
|
||||
#
|
||||
# Let's see an example to clear the ideas. We will be verbose to explain in details the input form. We wish to create the function:
|
||||
# $$
|
||||
# f(x, y) = [\sin(\pi x) \sin(\pi y), -\sin(\pi x) \sin(\pi y)] \quad (x,y)\in[0,1]\times[0,1]
|
||||
# $$
|
||||
#
|
||||
# using a batch size of one.
|
||||
|
||||
# In[2]:
|
||||
|
||||
|
||||
# batch size fixed to 1
|
||||
batch_size = 1
|
||||
|
||||
# points in the mesh fixed to 200
|
||||
N = 200
|
||||
|
||||
# vectorial 2 dimensional function, number_input_fileds=2
|
||||
number_input_fileds = 2
|
||||
|
||||
# 2 dimensional spatial variables, D = 2 + 1 = 3
|
||||
D = 3
|
||||
|
||||
# create the function f domain as random 2d points in [0, 1]
|
||||
domain = torch.rand(size=(batch_size, number_input_fileds, N, D-1))
|
||||
print(f"Domain has shape: {domain.shape}")
|
||||
|
||||
# create the functions
|
||||
pi = torch.acos(torch.tensor([-1.])) # pi value
|
||||
f1 = torch.sin(pi * domain[:, 0, :, 0]) * torch.sin(pi * domain[:, 0, :, 1])
|
||||
f2 = - torch.sin(pi * domain[:, 1, :, 0]) * torch.sin(pi * domain[:, 1, :, 1])
|
||||
|
||||
# stacking the input domain and field values
|
||||
data = torch.empty(size=(batch_size, number_input_fileds, N, D))
|
||||
data[..., :-1] = domain # copy the domain
|
||||
data[:, 0, :, -1] = f1 # copy first field value
|
||||
data[:, 1, :, -1] = f1 # copy second field value
|
||||
print(f"Filter input data has shape: {data.shape}")
|
||||
|
||||
|
||||
# ### Stride
|
||||
#
|
||||
# The stride is passed as a dictionary `stride` which tells the filter where to go. Here is an example for the $[0,1]\times[0,5]$ domain:
|
||||
#
|
||||
# ```python
|
||||
# # stride definition
|
||||
# stride = {"domain": [1, 5],
|
||||
# "start": [0, 0],
|
||||
# "jump": [0.1, 0.3],
|
||||
# "direction": [1, 1],
|
||||
# }
|
||||
# ```
|
||||
# This tells the filter:
|
||||
# 1. `domain`: square domain (the only implemented) $[0,1]\times[0,5]$. The minimum value is always zero, while the maximum is specified by the user
|
||||
# 2. `start`: start position of the filter, coordinate $(0, 0)$
|
||||
# 3. `jump`: the jumps of the centroid of the filter to the next position $(0.1, 0.3)$
|
||||
# 4. `direction`: the directions of the jump, with `1 = right`, `0 = no jump`,`-1 = left` with respect to the current position
|
||||
#
|
||||
# **Note**
|
||||
#
|
||||
# We are planning to release the possibility to directly pass a list of possible strides!
|
||||
|
||||
# ### Filter definition
|
||||
#
|
||||
# Having defined all the previous blocks we are able to construct the continuous filter.
|
||||
#
|
||||
# Suppose we would like to get an ouput with only one field, and let us fix the filter dimension to be $[0.1, 0.1]$.
|
||||
|
||||
# In[4]:
|
||||
|
||||
|
||||
# filter dim
|
||||
filter_dim = [0.1, 0.1]
|
||||
|
||||
# stride
|
||||
stride = {"domain": [1, 1],
|
||||
"start": [0, 0],
|
||||
"jump": [0.08, 0.08],
|
||||
"direction": [1, 1],
|
||||
}
|
||||
|
||||
# creating the filter
|
||||
cConv = ContinuousConv(input_numb_field=number_input_fileds,
|
||||
output_numb_field=1,
|
||||
filter_dim=filter_dim,
|
||||
stride=stride)
|
||||
|
||||
|
||||
# That's it! In just one line of code we have created the continuous convolutional filter. By default the `pina.model.FeedForward` neural network is intitialised, more on the [documentation](https://mathlab.github.io/PINA/_rst/fnn.html). In case the mesh doesn't change during training we can set the `optimize` flag equals to `True`, to exploit optimizations for finding the points to convolve.
|
||||
|
||||
# In[5]:
|
||||
|
||||
|
||||
# creating the filter + optimization
|
||||
cConv = ContinuousConv(input_numb_field=number_input_fileds,
|
||||
output_numb_field=1,
|
||||
filter_dim=filter_dim,
|
||||
stride=stride,
|
||||
optimize=True)
|
||||
|
||||
|
||||
# Let's try to do a forward pass
|
||||
|
||||
# In[6]:
|
||||
|
||||
|
||||
print(f"Filter input data has shape: {data.shape}")
|
||||
|
||||
#input to the filter
|
||||
output = cConv(data)
|
||||
|
||||
print(f"Filter output data has shape: {output.shape}")
|
||||
|
||||
|
||||
# If we don't want to use the default `FeedForward` neural network, we can pass a specified torch model in the `model` keyword as follow:
|
||||
#
|
||||
|
||||
# In[7]:
|
||||
|
||||
|
||||
class SimpleKernel(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self. model = torch.nn.Sequential(
|
||||
torch.nn.Linear(2, 20),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(20, 20),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(20, 1))
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
|
||||
|
||||
cConv = ContinuousConv(input_numb_field=number_input_fileds,
|
||||
output_numb_field=1,
|
||||
filter_dim=filter_dim,
|
||||
stride=stride,
|
||||
optimize=True,
|
||||
model=SimpleKernel)
|
||||
|
||||
|
||||
# Notice that we pass the class and not an already built object!
|
||||
|
||||
# ## Building a MNIST Classifier
|
||||
#
|
||||
# Let's see how we can build a MNIST classifier using a continuous convolutional filter. We will use the MNIST dataset from PyTorch. In order to keep small training times we use only 6000 samples for training and 1000 samples for testing.
|
||||
|
||||
# In[9]:
|
||||
|
||||
|
||||
from torch.utils.data import DataLoader, SubsetRandomSampler
|
||||
|
||||
numb_training = 6000 # get just 6000 images for training
|
||||
numb_testing= 1000 # get just 1000 images for training
|
||||
seed = 111 # for reproducibility
|
||||
batch_size = 8 # setting batch size
|
||||
|
||||
# setting the seed
|
||||
torch.manual_seed(seed)
|
||||
|
||||
# downloading the dataset
|
||||
train_data = torchvision.datasets.MNIST('./data/', train=True, download=True,
|
||||
transform=torchvision.transforms.Compose([
|
||||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize(
|
||||
(0.1307,), (0.3081,))
|
||||
]))
|
||||
subsample_train_indices = torch.randperm(len(train_data))[:numb_training]
|
||||
train_loader = DataLoader(train_data, batch_size=batch_size,
|
||||
sampler=SubsetRandomSampler(subsample_train_indices))
|
||||
|
||||
test_data = torchvision.datasets.MNIST('./data/', train=False, download=True,
|
||||
transform=torchvision.transforms.Compose([
|
||||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize(
|
||||
(0.1307,), (0.3081,))
|
||||
]))
|
||||
subsample_test_indices = torch.randperm(len(train_data))[:numb_testing]
|
||||
test_loader = DataLoader(train_data, batch_size=batch_size,
|
||||
sampler=SubsetRandomSampler(subsample_train_indices))
|
||||
|
||||
|
||||
# Let's now build a simple classifier. The MNIST dataset is composed by vectors of shape `[batch, 1, 28, 28]`, but we can image them as one field functions where the pixels $ij$ are the coordinate $x=i, y=j$ in a $[0, 27]\times[0,27]$ domain, and the pixels value are the field values. We just need a function to transform the regular tensor in a tensor compatible for the continuous filter:
|
||||
|
||||
# In[10]:
|
||||
|
||||
|
||||
def transform_input(x):
|
||||
batch_size = x.shape[0]
|
||||
dim_grid = tuple(x.shape[:-3:-1])
|
||||
|
||||
# creating the n dimensional mesh grid for a single channel image
|
||||
values_mesh = [torch.arange(0, dim).float() for dim in dim_grid]
|
||||
mesh = torch.meshgrid(values_mesh)
|
||||
coordinates_mesh = [x.reshape(-1, 1) for x in mesh]
|
||||
coordinates = torch.cat(coordinates_mesh, dim=1).unsqueeze(
|
||||
0).repeat((batch_size, 1, 1)).unsqueeze(1)
|
||||
|
||||
return torch.cat((coordinates, x.flatten(2).unsqueeze(-1)), dim=-1)
|
||||
|
||||
|
||||
# let's try it out
|
||||
image, s = next(iter(train_loader))
|
||||
print(f"Original MNIST image shape: {image.shape}")
|
||||
|
||||
image_transformed = transform_input(image)
|
||||
print(f"Transformed MNIST image shape: {image_transformed.shape}")
|
||||
|
||||
|
||||
# We can now build a simple classifier! We will use just one convolutional filter followed by a feedforward neural network
|
||||
|
||||
# In[19]:
|
||||
|
||||
|
||||
# setting the seed
|
||||
torch.manual_seed(seed)
|
||||
|
||||
class ContinuousClassifier(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# number of classes for classification
|
||||
numb_class = 10
|
||||
|
||||
# convolutional block
|
||||
self.convolution = ContinuousConv(input_numb_field=1,
|
||||
output_numb_field=4,
|
||||
stride={"domain": [27, 27],
|
||||
"start": [0, 0],
|
||||
"jumps": [4, 4],
|
||||
"direction": [1, 1.],
|
||||
},
|
||||
filter_dim=[4, 4],
|
||||
optimize=True)
|
||||
# feedforward net
|
||||
self.nn = FeedForward(input_variables=196,
|
||||
output_variables=numb_class,
|
||||
layers=[120, 64],
|
||||
func=torch.nn.ReLU)
|
||||
|
||||
def forward(self, x):
|
||||
# transform input + convolution
|
||||
x = transform_input(x)
|
||||
x = self.convolution(x)
|
||||
# feed forward classification
|
||||
return self.nn(x[..., -1].flatten(1))
|
||||
|
||||
|
||||
net = ContinuousClassifier()
|
||||
|
||||
|
||||
# Let's try to train it using a simple pytorch training loop. We train for juts 1 epoch using Adam optimizer with a $0.001$ learning rate.
|
||||
|
||||
# In[20]:
|
||||
|
||||
|
||||
# setting the seed
|
||||
torch.manual_seed(seed)
|
||||
|
||||
# optimizer and loss function
|
||||
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
for epoch in range(1): # loop over the dataset multiple times
|
||||
|
||||
running_loss = 0.0
|
||||
for i, data in enumerate(train_loader, 0):
|
||||
# get the inputs; data is a list of [inputs, labels]
|
||||
inputs, labels = data
|
||||
|
||||
# zero the parameter gradients
|
||||
optimizer.zero_grad()
|
||||
|
||||
# forward + backward + optimize
|
||||
outputs = net(inputs)
|
||||
loss = criterion(outputs, labels)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# print statistics
|
||||
running_loss += loss.item()
|
||||
if i % 50 == 49:
|
||||
print(
|
||||
f'epoch [{i + 1}/{numb_training//batch_size}] loss[{running_loss / 500:.3f}]')
|
||||
running_loss = 0.0
|
||||
|
||||
|
||||
# Let's see the performance on the train set!
|
||||
|
||||
# In[21]:
|
||||
|
||||
|
||||
correct = 0
|
||||
total = 0
|
||||
with torch.no_grad():
|
||||
for data in test_loader:
|
||||
images, labels = data
|
||||
# calculate outputs by running images through the network
|
||||
outputs = net(images)
|
||||
# the class with the highest energy is what we choose as prediction
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
total += labels.size(0)
|
||||
correct += (predicted == labels).sum().item()
|
||||
|
||||
print(
|
||||
f'Accuracy of the network on the 1000 test images: {(correct / total):.3%}')
|
||||
|
||||
|
||||
# As we can see we have very good performance for having traing only for 1 epoch! Nevertheless, we are still using structured data... Let's see how we can build an autoencoder for unstructured data now.
|
||||
|
||||
# ## Building a Continuous Convolutional Autoencoder
|
||||
#
|
||||
# Just as toy problem, we will now build an autoencoder for the following function $f(x,y)=\sin(\pi x)\sin(\pi y)$ on the unit circle domain centered in $(0.5, 0.5)$. We will also see the ability to up-sample (once trained) the results without retraining. Let's first create the input and visualize it, we will use firstly a mesh of $100$ points.
|
||||
|
||||
# In[22]:
|
||||
|
||||
|
||||
# create inputs
|
||||
def circle_grid(N=100):
|
||||
"""Generate points withing a unit 2D circle centered in (0.5, 0.5)
|
||||
|
||||
:param N: number of points
|
||||
:type N: float
|
||||
:return: [x, y] array of points
|
||||
:rtype: torch.tensor
|
||||
"""
|
||||
|
||||
PI = torch.acos(torch.zeros(1)).item() * 2
|
||||
R = 0.5
|
||||
centerX = 0.5
|
||||
centerY = 0.5
|
||||
|
||||
r = R * torch.sqrt(torch.rand(N))
|
||||
theta = torch.rand(N) * 2 * PI
|
||||
|
||||
x = centerX + r * torch.cos(theta)
|
||||
y = centerY + r * torch.sin(theta)
|
||||
|
||||
return torch.stack([x, y]).T
|
||||
|
||||
# create the grid
|
||||
grid = circle_grid(500)
|
||||
|
||||
# create input
|
||||
input_data = torch.empty(size=(1, 1, grid.shape[0], 3))
|
||||
input_data[0, 0, :, :-1] = grid
|
||||
input_data[0, 0, :, -1] = torch.sin(pi * grid[:, 0]) * torch.sin(pi * grid[:, 1])
|
||||
|
||||
# visualize data
|
||||
plt.title("Training sample with 500 points")
|
||||
plt.scatter(grid[:, 0], grid[:, 1], c=input_data[0, 0, :, -1])
|
||||
plt.colorbar()
|
||||
plt.show()
|
||||
|
||||
|
||||
# Let's now build a simple autoencoder using the continuous convolutional filter. The data is clearly unstructured and a simple convolutional filter might not work without projecting or interpolating first. Let's first build and `Encoder` and `Decoder` class, and then a `Autoencoder` class that contains both.
|
||||
|
||||
# In[23]:
|
||||
|
||||
|
||||
class Encoder(torch.nn.Module):
|
||||
def __init__(self, hidden_dimension):
|
||||
super().__init__()
|
||||
|
||||
# convolutional block
|
||||
self.convolution = ContinuousConv(input_numb_field=1,
|
||||
output_numb_field=2,
|
||||
stride={"domain": [1, 1],
|
||||
"start": [0, 0],
|
||||
"jumps": [0.05, 0.05],
|
||||
"direction": [1, 1.],
|
||||
},
|
||||
filter_dim=[0.15, 0.15],
|
||||
optimize=True)
|
||||
# feedforward net
|
||||
self.nn = FeedForward(input_variables=400,
|
||||
output_variables=hidden_dimension,
|
||||
layers=[240, 120])
|
||||
|
||||
def forward(self, x):
|
||||
# convolution
|
||||
x = self.convolution(x)
|
||||
# feed forward pass
|
||||
return self.nn(x[..., -1])
|
||||
|
||||
|
||||
class Decoder(torch.nn.Module):
|
||||
def __init__(self, hidden_dimension):
|
||||
super().__init__()
|
||||
|
||||
# convolutional block
|
||||
self.convolution = ContinuousConv(input_numb_field=2,
|
||||
output_numb_field=1,
|
||||
stride={"domain": [1, 1],
|
||||
"start": [0, 0],
|
||||
"jumps": [0.05, 0.05],
|
||||
"direction": [1, 1.],
|
||||
},
|
||||
filter_dim=[0.15, 0.15],
|
||||
optimize=True)
|
||||
# feedforward net
|
||||
self.nn = FeedForward(input_variables=hidden_dimension,
|
||||
output_variables=400,
|
||||
layers=[120, 240])
|
||||
|
||||
def forward(self, weights, grid):
|
||||
# feed forward pass
|
||||
x = self.nn(weights)
|
||||
# transpose convolution
|
||||
return torch.sigmoid(self.convolution.transpose(x, grid))
|
||||
|
||||
|
||||
# Very good! Notice that in the `Decoder` class in the `forward` pass we have used the `.transpose()` method of the `ContinuousConvolution` class. This method accepts the `weights` for upsampling and the `grid` on where to upsample. Let's now build the autoencoder! We set the hidden dimension in the `hidden_dimension` variable. We apply the sigmoid on the output since the field value is between $[0, 1]$.
|
||||
|
||||
# In[28]:
|
||||
|
||||
|
||||
class Autoencoder(torch.nn.Module):
|
||||
def __init__(self, hidden_dimension=10):
|
||||
super().__init__()
|
||||
|
||||
self.encoder = Encoder(hidden_dimension)
|
||||
self.decoder = Decoder(hidden_dimension)
|
||||
|
||||
def forward(self, x):
|
||||
# saving grid for later upsampling
|
||||
grid = x.clone().detach()
|
||||
# encoder
|
||||
weights = self.encoder(x)
|
||||
# decoder
|
||||
out = self.decoder(weights, grid)
|
||||
return out
|
||||
|
||||
|
||||
net = Autoencoder()
|
||||
|
||||
|
||||
# Let's now train the autoencoder, minimizing the mean square error loss and optimizing using Adam.
|
||||
|
||||
# In[29]:
|
||||
|
||||
|
||||
# setting the seed
|
||||
torch.manual_seed(seed)
|
||||
|
||||
# optimizer and loss function
|
||||
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
|
||||
criterion = torch.nn.MSELoss()
|
||||
max_epochs = 150
|
||||
|
||||
for epoch in range(max_epochs): # loop over the dataset multiple times
|
||||
|
||||
# zero the parameter gradients
|
||||
optimizer.zero_grad()
|
||||
|
||||
# forward + backward + optimize
|
||||
outputs = net(input_data)
|
||||
loss = criterion(outputs[..., -1], input_data[..., -1])
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# print statistics
|
||||
if epoch % 10 ==9:
|
||||
print(f'epoch [{epoch + 1}/{max_epochs}] loss [{loss.item():.2}]')
|
||||
|
||||
|
||||
# Let's visualize the two solutions side by side!
|
||||
|
||||
# In[30]:
|
||||
|
||||
|
||||
net.eval()
|
||||
|
||||
# get output and detach from computational graph for plotting
|
||||
output = net(input_data).detach()
|
||||
|
||||
# visualize data
|
||||
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(8, 3))
|
||||
pic1 = axes[0].scatter(grid[:, 0], grid[:, 1], c=input_data[0, 0, :, -1])
|
||||
axes[0].set_title("Real")
|
||||
fig.colorbar(pic1)
|
||||
plt.subplot(1, 2, 2)
|
||||
pic2 = axes[1].scatter(grid[:, 0], grid[:, 1], c=output[0, 0, :, -1])
|
||||
axes[1].set_title("Autoencoder")
|
||||
fig.colorbar(pic2)
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
# As we can see the two are really similar! We can compute the $l_2$ error quite easily as well:
|
||||
|
||||
# In[32]:
|
||||
|
||||
|
||||
def l2_error(input_, target):
|
||||
return torch.linalg.norm(input_-target, ord=2)/torch.linalg.norm(input_, ord=2)
|
||||
|
||||
|
||||
print(f'l2 error: {l2_error(input_data[0, 0, :, -1], output[0, 0, :, -1]):.2%}')
|
||||
|
||||
|
||||
# More or less $4\%$ in $l_2$ error, which is really low considering the fact that we use just **one** convolutional layer and a simple feedforward to decrease the dimension. Let's see now some peculiarity of the filter.
|
||||
|
||||
# ### Filter for upsampling
|
||||
#
|
||||
# Suppose we have already the hidden dimension and we want to upsample on a differen grid with more points. Let's see how to do it:
|
||||
|
||||
# In[33]:
|
||||
|
||||
|
||||
# setting the seed
|
||||
torch.manual_seed(seed)
|
||||
|
||||
grid2 = circle_grid(1500) # triple number of points
|
||||
input_data2 = torch.zeros(size=(1, 1, grid2.shape[0], 3))
|
||||
input_data2[0, 0, :, :-1] = grid2
|
||||
input_data2[0, 0, :, -1] = torch.sin(pi *
|
||||
grid2[:, 0]) * torch.sin(pi * grid2[:, 1])
|
||||
|
||||
# get the hidden dimension representation from original input
|
||||
latent = net.encoder(input_data)
|
||||
|
||||
# upsample on the second input_data2
|
||||
output = net.decoder(latent, input_data2).detach()
|
||||
|
||||
# show the picture
|
||||
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(8, 3))
|
||||
pic1 = axes[0].scatter(grid2[:, 0], grid2[:, 1], c=input_data2[0, 0, :, -1])
|
||||
axes[0].set_title("Real")
|
||||
fig.colorbar(pic1)
|
||||
plt.subplot(1, 2, 2)
|
||||
pic2 = axes[1].scatter(grid2[:, 0], grid2[:, 1], c=output[0, 0, :, -1])
|
||||
axes[1].set_title("Up-sampling")
|
||||
fig.colorbar(pic2)
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
# As we can see we have a very good approximation of the original function, even thought some noise is present. Let's calculate the error now:
|
||||
|
||||
# In[34]:
|
||||
|
||||
|
||||
print(f'l2 error: {l2_error(input_data2[0, 0, :, -1], output[0, 0, :, -1]):.2%}')
|
||||
|
||||
|
||||
# ### Autoencoding at different resolution
|
||||
# In the previous example we already had the hidden dimension (of original input) and we used it to upsample. Sometimes however we have a more fine mesh solution and we simply want to encode it. This can be done without retraining! This procedure can be useful in case we have many points in the mesh and just a smaller part of them are needed for training. Let's see the results of this:
|
||||
|
||||
# In[36]:
|
||||
|
||||
|
||||
# setting the seed
|
||||
torch.manual_seed(seed)
|
||||
|
||||
grid2 = circle_grid(3500) # very fine mesh
|
||||
input_data2 = torch.zeros(size=(1, 1, grid2.shape[0], 3))
|
||||
input_data2[0, 0, :, :-1] = grid2
|
||||
input_data2[0, 0, :, -1] = torch.sin(pi *
|
||||
grid2[:, 0]) * torch.sin(pi * grid2[:, 1])
|
||||
|
||||
# get the hidden dimension representation from more fine mesh input
|
||||
latent = net.encoder(input_data2)
|
||||
|
||||
# upsample on the second input_data2
|
||||
output = net.decoder(latent, input_data2).detach()
|
||||
|
||||
# show the picture
|
||||
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(8, 3))
|
||||
pic1 = axes[0].scatter(grid2[:, 0], grid2[:, 1], c=input_data2[0, 0, :, -1])
|
||||
axes[0].set_title("Real")
|
||||
fig.colorbar(pic1)
|
||||
plt.subplot(1, 2, 2)
|
||||
pic2 = axes[1].scatter(grid2[:, 0], grid2[:, 1], c=output[0, 0, :, -1])
|
||||
axes[1].set_title("Autoencoder not re-trained")
|
||||
fig.colorbar(pic2)
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
# calculate l2 error
|
||||
print(
|
||||
f'l2 error: {l2_error(input_data2[0, 0, :, -1], output[0, 0, :, -1]):.2%}')
|
||||
|
||||
|
||||
# ## What's next?
|
||||
#
|
||||
# We have shown the basic usage of a convolutional filter. In the next tutorials we will show how to combine the PINA framework with the convolutional filter to train in few lines and efficiently a Neural Network!
|
||||
Reference in New Issue
Block a user