Files
PINA/pina/model/layers/convolution_2d.py
Dario Coscia 17464ceca9 New Residual Model and Fix relative import
* Adding Residual MLP
* Adding test Residual MLP
* Modified relative import Continuous Conv
2023-11-17 09:51:29 +01:00

548 lines
21 KiB
Python

"""Module for Continuous Convolution class"""
from .convolution import BaseContinuousConv
from .utils_convolution import check_point, map_points_
from .integral import Integral
import torch
class ContinuousConvBlock(BaseContinuousConv):
    """
    Implementation of Continuous Convolutional operator.

    .. seealso::
        **Original reference**: Coscia, D., Meneghetti, L., Demo, N.,
        Stabile, G., & Rozza, G. (2022). A Continuous Convolutional Trainable
        Filter for Modelling Unstructured Data.
        DOI: `10.48550/arXiv.2210.13416
        <https://doi.org/10.48550/arXiv.2210.13416>`_.
    """

    def __init__(self, input_numb_field, output_numb_field,
                 filter_dim, stride, model=None, optimize=False,
                 no_overlap=False):
        """
        :param input_numb_field: Number of fields N_in in the input.
        :type input_numb_field: int
        :param output_numb_field: Number of fields N_out in the output.
        :type output_numb_field: int
        :param filter_dim: Dimension of the filter.
        :type filter_dim: tuple/ list
        :param stride: Stride for the filter.
        :type stride: dict
        :param model: Neural network for inner parametrization,
            defaults to None. If None, a default multilayer perceptron
            is used, see BaseContinuousConv.DefaultKernel. When given, it
            must be a class (or any zero-argument factory) returning a
            new ``torch.nn.Module`` each time it is called, not an
            already-built module instance.
        :type model: torch.nn.Module, optional
        :param optimize: Flag for performing optimization on the continuous
            filter, defaults to False. The flag `optimize=True` should be
            used only when the scatter datapoints are fixed through the
            training. If torch model is in `.eval()` mode, the flag is
            automatically set to False always.
        :type optimize: bool, optional
        :param no_overlap: Flag for performing optimization on the transpose
            continuous filter, defaults to False. The flag set to `True`
            should be used only when the filter positions do not overlap for
            different strides. RuntimeError will raise in case of
            non-compatible strides.
        :type no_overlap: bool, optional

        .. note::
            Using `optimize=True` the filter can be used either in `forward`
            or in `transpose` mode, not both. If `optimize=False` the same
            filter can be used for both `transpose` and `forward` modes.

        .. warning::
            The algorithm expects input to be in the form:
            [B x N_in x N x D]
            where B is the batch_size, N_in is the number of input
            fields, N the number of points in the mesh, D the dimension
            of the problem. In particular:

            * D is the number of spatial variables + 1. The last column must
              contain the field value. For example for 2D problems D=3 and
              the tensor will be something like `[first coordinate, second
              coordinate, field value]`.
            * N_in represents the number of vectorial function presented.
              For example a vectorial function f = [f_1, f_2] will have
              N_in=2.

            The algorithm returns a tensor of shape: [B x N_out x N' x D]
            where B is the batch_size, N_out is the number of output
            fields, N' the number of points in the output mesh, D the
            dimension of the problem (coordinates + field value).

            For example, a 2-dimensional vectorial function N_in=2 of
            3-dimensional input D=3+1=4 with 100 points input mesh and batch
            size of 8 is represented as a tensor `[8, 2, 100, 4]`, where the
            columns `[:, 0, :, -1]` and `[:, 1, :, -1]` represent the first
            and second field value respectively.

        :Example:
            >>> class MLP(torch.nn.Module):
                    def __init__(self) -> None:
                        super().__init__()
                        self.model = torch.nn.Sequential(
                            torch.nn.Linear(2, 8),
                            torch.nn.ReLU(),
                            torch.nn.Linear(8, 8),
                            torch.nn.ReLU(),
                            torch.nn.Linear(8, 1))
                    def forward(self, x):
                        return self.model(x)
            >>> dim = [3, 3]
            >>> stride = {"domain": [10, 10],
                          "start": [0, 0],
                          "jumps": [3, 3],
                          "direction": [1, 1.]}
            >>> conv = ContinuousConvBlock(1, 2, dim, stride, MLP)
            >>> conv
            ContinuousConvBlock(
              (_net): ModuleList(
                (0): MLP(
                  (model): Sequential(
                    (0): Linear(in_features=2, out_features=8, bias=True)
                    (1): ReLU()
                    (2): Linear(in_features=8, out_features=8, bias=True)
                    (3): ReLU()
                    (4): Linear(in_features=8, out_features=1, bias=True)
                  )
                )
                (1): MLP(
                  (model): Sequential(
                    (0): Linear(in_features=2, out_features=8, bias=True)
                    (1): ReLU()
                    (2): Linear(in_features=8, out_features=8, bias=True)
                    (3): ReLU()
                    (4): Linear(in_features=8, out_features=1, bias=True)
                  )
                )
              )
            )
        """
        super().__init__(input_numb_field=input_numb_field,
                         output_numb_field=output_numb_field,
                         filter_dim=filter_dim,
                         stride=stride,
                         model=model,
                         optimize=optimize,
                         no_overlap=no_overlap)

        # integral routine
        self._integral = Integral('discrete')

        # create the networks: one kernel per (output field, input field)
        # pair
        self._net = self._spawn_networks(model)

        # stride for continuous convolution overridden: from here on
        # `self._stride` is the `_stride_discrete` attribute of the stride
        # object set up by the parent constructor
        self._stride = self._stride._stride_discrete

    def _spawn_networks(self, model):
        """Private method to create a collection of kernels.

        One kernel is created for every (input field, output field) pair,
        i.e. ``input_numb_field * output_numb_field`` kernels in total.

        :param model: a torch.nn.Module class (or any zero-argument callable
            returning a new torch.nn.Module); it is called once per kernel.
            Ignored when no model was stored on the instance.
        :type model: torch.nn.Module
        :raises ValueError: if ``model`` is an already-instantiated module
            or is not callable, or if calling it does not return a
            torch.nn.Module.
        :return: list of torch.nn.Module models
        :rtype: torch.nn.ModuleList
        """
        n_kernels = self._input_numb_field * self._output_numb_field

        # At this point `self._net` has not been reassigned by this class
        # yet, so it still holds what the parent constructor stored
        # (presumably the `model` argument -- confirm in BaseContinuousConv).
        if self._net is None:
            kernels = [
                ContinuousConvBlock.DefaultKernel(len(self._dim), 1)
                for _ in range(n_kernels)
            ]
            return torch.nn.ModuleList(kernels)

        # `isinstance(model, object)` is true for every Python value, so it
        # validated nothing. What is needed is a factory producing a fresh
        # module per kernel: reject module instances (they would be shared,
        # and calling them runs `forward`) and anything not callable.
        if isinstance(model, torch.nn.Module) or not callable(model):
            raise ValueError("Expected a python class inheriting"
                             " from torch.nn.Module")

        kernels = []
        for _ in range(n_kernels):
            kernel = model()
            if not isinstance(kernel, torch.nn.Module):
                raise ValueError("The python class must be inherited from"
                                 " torch.nn.Module. See the docstring for"
                                 " an example.")
            kernels.append(kernel)
        return torch.nn.ModuleList(kernels)

    def _extract_mapped_points(self, batch_idx, index, x):
        """Private method to extract mapped points in the filter.

        :param batch_idx: position of ``x`` inside the batch.
        :type batch_idx: int
        :param index: indices computed by ``_find_index``, addressed as
            ``index[stride_idx][batch_idx]`` (presumably boolean masks of
            shape [channel x N] -- confirm against ``check_point``).
        :type index: list
        :param x: input tensor [channel x N x dim]
        :type x: torch.tensor
        :return: a tuple with, for each channel, the mapped points of all
            strides concatenated, and a tuple with, for each channel, the
            number of points found in each stride.
        :rtype: tuple(tuple(torch.tensor), tuple(tuple(int)))
        """
        mapped_points = []
        points_per_stride = []

        for stride_idx, current_stride in enumerate(self._stride):
            # indices of points falling into filter range
            inside = index[stride_idx][batch_idx]
            # how many points for each channel fall into the filter?
            numb_points_inside = torch.sum(inside, dim=-1).tolist()
            # extracting the points of all channels at once
            # shape: [sum(numb_points_inside), filter_dim + 1]
            point_stride = x[inside]
            # mapping points in filter domain (in place, coordinates only)
            map_points_(point_stride[..., :-1], current_stride)
            # splitting the points back into one chunk per channel
            mapped_points.append(point_stride.split(numb_points_inside))
            points_per_stride.append(numb_points_inside)

        # regroup from stride-major to channel-major and concatenate the
        # strides, so each kernel can be evaluated with a single call
        stacked_input = tuple(
            torch.cat(per_channel) for per_channel in zip(*mapped_points)
        )
        indeces_channels = tuple(zip(*points_per_stride))
        return stacked_input, indeces_channels

    def _find_index(self, X):
        """Private method to extract indices for convolution.

        The result (one entry per stride, as returned by ``check_point``)
        is stored in ``self._index``.

        :param X: input tensor, as in ContinuousConvBlock docstring
        :type X: torch.tensor
        """
        self._index = [
            check_point(X, current_stride, self._dim)
            for current_stride in self._stride
        ]

    def _make_grid_forward(self, X):
        """Private method to create forward convolution grid.

        The grid has shape [B x N_out x N' x D] with N' = number of strides
        and a zero last column; it is stored, detached, in ``self._grid``.

        :param X: input tensor, as in ContinuousConvBlock docstring
        :type X: torch.tensor
        """
        # filter dimension + number of points in output grid
        filter_dim = len(self._dim)
        number_points = len(self._stride)

        # initialize the grid
        grid = torch.zeros(size=(X.shape[0],
                                 self._output_numb_field,
                                 number_points,
                                 filter_dim + 1),
                           device=X.device,
                           dtype=X.dtype)
        # coordinates of the output points: each stride position shifted by
        # half of the filter size
        grid[..., :-1] = (self._stride + self._dim * 0.5)

        # saving the grid
        self._grid = grid.detach()

    def _make_grid_transpose(self, X):
        """Private method to create transpose convolution grid.

        The grid keeps the coordinates of ``X`` and has a zero last column;
        it is stored in ``self._grid_transpose``.

        :param X: input tensor, as in ContinuousConvBlock docstring
        :type X: torch.tensor
        """
        # initialize to all zeros and copy everything but the field value
        grid = torch.zeros_like(X)
        grid[..., :-1] = X[..., :-1]
        self._grid_transpose = grid

    def _make_grid(self, X, type):
        """Private method to create convolution grid.

        :param X: input tensor, as in ContinuousConvBlock docstring
        :type X: torch.tensor
        :param type: type of convolution, ['forward', 'inverse'] the
            possibilities
        :type type: string
        :raises TypeError: if ``type`` is neither 'forward' nor 'inverse'.
        """
        # NOTE: the parameter name shadows the builtin `type`; it is kept
        # because callers pass it by keyword.
        if type == 'forward':
            self._make_grid_forward(X)
        elif type == 'inverse':
            self._make_grid_transpose(X)
        else:
            raise TypeError(
                "Invalid convolution type {!r}: expected 'forward' or"
                " 'inverse'.".format(type))

    def _initialize_convolution(self, X, type='forward'):
        """Private method to initialize the convolution.

        The convolution is initialized by setting a grid and
        calculating the index for finding the points inside the
        filter.

        :param X: input tensor, as in ContinuousConvBlock docstring
        :type X: torch.tensor
        :param type: type of convolution, ['forward', 'inverse'] the
            possibilities
        :type type: string
        """
        # variable for the convolution
        self._make_grid(X, type)
        # calculate the index
        self._find_index(X)

    def forward(self, X):
        """Forward pass in the layer.

        :param X: input data, tensor of shape [B x N_in x N x D] as
            described in the constructor docstring.
        :type X: torch.tensor
        :return: feed forward convolution: a copy of the forward grid whose
            last column holds the convolution result, shape
            [B x N_out x N' x D].
        :rtype: torch.tensor
        """
        # initialize convolution
        if self.training:
            # we choose what to do based on optimization; the method
            # `_choose_initialization` is not defined in this class
            # (presumably provided by BaseContinuousConv)
            self._choose_initialization(X, type='forward')
        else:
            # we always initialize on testing
            self._initialize_convolution(X, 'forward')

        # create convolutional array
        conv = self._grid.clone().detach()

        # total number of kernels
        tot_dim = self._output_numb_field * self._input_numb_field

        for batch_idx, x in enumerate(X):
            # extract mapped points
            stacked_input, indeces_channels = self._extract_mapped_points(
                batch_idx, self._index, x)

            # storing intermediate results for each kernel
            res_tmp = []
            for idx_conv in range(tot_dim):
                # index of the input field: kernels are laid out as
                # [output field, input field], input field varying fastest
                idx = idx_conv % self._input_numb_field
                # extract input for each channel
                single_channel_input = stacked_input[idx]
                # extract filter
                net = self._net[idx_conv]
                # calculate filter value at the mapped coordinates
                staked_output = net(single_channel_input[..., :-1])
                # perform integral for all strides in one field
                integral = self._integral(staked_output,
                                          single_channel_input[..., -1],
                                          indeces_channels[idx])
                res_tmp.append(integral)

            # stacking integral results
            res_tmp = torch.stack(res_tmp)

            # sum filters (for each input field) in groups
            # for different output fields
            conv[batch_idx, ..., -1] = res_tmp.reshape(
                self._output_numb_field,
                self._input_numb_field,
                -1).sum(1)
        return conv

    def transpose_no_overlap(self, integrals, X):
        """Transpose pass in the layer for no-overlapping filters.

        :param integrals: Weights for the transpose convolution. Shape
            [B x N_in x N]
            where B is the batch_size, N_in is the number of input
            fields, N the number of points in the mesh.
        :type integrals: torch.tensor
        :param X: Input data. Expect tensor of shape
            [B x N_in x M x D] where B is the batch_size,
            N_in is the number of input fields, M the number of points
            in the mesh, D the dimension of the problem. Note, only the
            coordinate columns are used: the last column is ignored.
        :type X: torch.tensor
        :return: Feed forward transpose convolution: a copy of the
            transpose grid (built from ``X``) whose last column holds the
            result.
        :rtype: torch.tensor

        .. note::
            This function is automatically called when `.transpose()`
            method is used and `no_overlap=True`
        """
        # initialize convolution
        if self.training:
            # we choose what to do based on optimization
            self._choose_initialization(X, type='inverse')
        else:
            # we always initialize on testing
            self._initialize_convolution(X, 'inverse')

        # initialize grid: from here on X is the stored grid (coordinates
        # with a zeroed field column), not the caller's tensor
        X = self._grid_transpose.clone().detach()
        conv_transposed = self._grid_transpose.clone().detach()

        # total number of kernels
        tot_dim = self._input_numb_field * self._output_numb_field

        for batch_idx, x in enumerate(X):
            # extract mapped points
            stacked_input, indeces_channels = self._extract_mapped_points(
                batch_idx, self._index, x)

            # compute the transpose convolution
            res_tmp = []
            for idx_conv in range(tot_dim):
                # NOTE(review): both indices below are taken modulo the
                # kernel counter, which differs from the
                # [output field, input field] layout used by `forward` and
                # `transpose_overlap` -- confirm this is intended.
                # index for each output field
                idx = idx_conv % self._output_numb_field
                # index for each input field
                idx_in = idx_conv % self._input_numb_field
                # extract input for each field
                single_channel_input = stacked_input[idx]
                # the repeat counts must live on the same device as the
                # tensor being repeated
                rep_idx = torch.tensor(indeces_channels[idx],
                                       device=integrals.device)
                integral = integrals[batch_idx,
                                     idx_in, :].repeat_interleave(rep_idx)
                # extract filter
                net = self._net[idx_conv]
                # perform transpose convolution for all strides in one field
                staked_output = net(
                    single_channel_input[..., :-1]).flatten()
                integral = staked_output * integral
                res_tmp.append(integral)

            # stacking integral results and sum
            # filters (for each input fields) in groups
            # for different output fields
            res_tmp = torch.stack(res_tmp).reshape(
                self._input_numb_field,
                self._output_numb_field,
                -1).sum(0)
            conv_transposed[batch_idx, ..., -1] = res_tmp
        return conv_transposed

    def transpose_overlap(self, integrals, X):
        """Transpose pass in the layer for overlapping filters.

        :param integrals: Weights for the transpose convolution. Shape
            [B x N_in x N]
            where B is the batch_size, N_in is the number of input
            fields, N the number of points in the mesh.
        :type integrals: torch.tensor
        :param X: Input data. Expect tensor of shape
            [B x N_in x M x D] where B is the batch_size,
            N_in is the number of input fields, M the number of points
            in the mesh, D the dimension of the problem. Note, only the
            coordinate columns are used: the last column is ignored.
        :type X: torch.tensor
        :return: Feed forward transpose convolution: a copy of the
            transpose grid (built from ``X``) whose last column holds the
            result.
        :rtype: torch.tensor

        .. note:: This function is automatically called when `.transpose()`
            method is used and `no_overlap=False`
        """
        # initialize convolution
        if self.training:
            # we choose what to do based on optimization
            self._choose_initialization(X, type='inverse')
        else:
            # we always initialize on testing
            self._initialize_convolution(X, 'inverse')

        # initialize grid: from here on X is the stored grid (coordinates
        # with a zeroed field column), not the caller's tensor
        X = self._grid_transpose.clone().detach()
        conv_transposed = self._grid_transpose.clone().detach()

        # output-channel index of every kernel: each output channel is
        # repeated once per input field, e.g. [0, 0, 1, 1] for
        # N_in = N_out = 2
        iterate_conv = [
            idx_out
            for idx_out in range(self._output_numb_field)
            for _ in range(self._input_numb_field)
        ]

        for batch_idx, x in enumerate(X):
            # accumulator for the convolution of this batch element; the
            # clone turns the leaf tensor into a non-leaf one so that it
            # can be updated in place while tracking gradients
            accumulator_batch = torch.zeros(
                size=(self._grid_transpose.shape[1],
                      self._grid_transpose.shape[2]),
                requires_grad=True,
                device=X.device,
                dtype=X.dtype).clone()

            for stride_idx, current_stride in enumerate(self._stride):
                # indices of points falling into filter range
                indeces = self._index[stride_idx][batch_idx]
                # number of points for each channel
                numb_pts_channel = indeces.sum(dim=-1).tolist()
                # extracting points for each channel
                point_stride = x[indeces]
                # if no points to upsample we just skip
                if point_stride.nelement() == 0:
                    continue
                # mapping points in filter domain (in place)
                map_points_(point_stride[..., :-1], current_stride)
                # input points for kernels
                # we split for extracting number of points for each channel
                nn_input_pts = point_stride[..., :-1].split(
                    numb_pts_channel)

                # accumulate partial convolution results for each kernel
                res_tmp = []
                for idx_conv, idx_channel_out in enumerate(iterate_conv):
                    # index for input channels
                    idx_channel_in = idx_conv % self._input_numb_field
                    # extract filter
                    net = self._net[idx_conv]
                    # calculate filter value
                    staked_output = net(nn_input_pts[idx_channel_out])
                    # weight the filter by the integral of this stride
                    integral = staked_output * integrals[batch_idx,
                                                         idx_channel_in,
                                                         stride_idx]
                    # append results
                    res_tmp.append(integral.flatten())

                # computing channel sum: consecutive groups of N_in
                # results belong to the same output channel
                channel_sum = []
                start = 0
                for _ in range(self._output_numb_field):
                    group = res_tmp[start:start + self._input_numb_field]
                    channel_sum.append(torch.vstack(group).sum(dim=0))
                    start += self._input_numb_field

                # accumulate the results
                accumulator_batch[indeces] += torch.hstack(channel_sum)

            # save results of accumulation for each batch
            conv_transposed[batch_idx, ..., -1] = accumulator_batch
        return conv_transposed