Continuous Convolution (#69)
* network handling update * adding tutorial * docs
This commit is contained in:
7
pina/model/layers/__init__.py
Normal file
7
pina/model/layers/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
__all__ = [
|
||||
'BaseContinuousConv',
|
||||
'ContinuousConv'
|
||||
]
|
||||
|
||||
from .convolution import BaseContinuousConv
|
||||
from .convolution_2d import ContinuousConv
|
||||
154
pina/model/layers/convolution.py
Normal file
154
pina/model/layers/convolution.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""Module for Base Continuous Convolution class."""
|
||||
from abc import ABCMeta, abstractmethod
|
||||
import torch
|
||||
from .stride import Stride
|
||||
from .utils_convolution import optimizing
|
||||
|
||||
|
||||
class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
|
||||
"""
|
||||
Abstract class
|
||||
"""
|
||||
|
||||
def __init__(self, input_numb_field, output_numb_field,
|
||||
filter_dim, stride, model=None, optimize=False,
|
||||
no_overlap=False):
|
||||
"""Base Class for Continuous Convolution.
|
||||
|
||||
The algorithm expects input to be in the form:
|
||||
$$[B \times N_{in} \times N \times D]$$
|
||||
where $B$ is the batch_size, $N_{in}$ is the number of input
|
||||
fields, $N$ the number of points in the mesh, $D$ the dimension
|
||||
of the problem. In particular:
|
||||
* $D$ is the number of spatial variables + 1. The last column must
|
||||
contain the field value. For example for 2D problems $D=3$ and
|
||||
the tensor will be something like `[first coordinate, second
|
||||
coordinate, field value]`.
|
||||
* $N_{in}$ represents the number of vectorial function presented.
|
||||
For example a vectorial function $f = [f_1, f_2]$ will have
|
||||
$N_{in}=2$.
|
||||
|
||||
:Note
|
||||
A 2-dimensional vectorial function $N_{in}=2$ of 3-dimensional
|
||||
input $D=3+1=4$ with 100 points input mesh and batch size of 8
|
||||
is represented as a tensor `[8, 2, 100, 4]`, where the columns
|
||||
`[:, 0, :, -1]` and `[:, 1, :, -1]` represent the first and
|
||||
second filed value respectively
|
||||
|
||||
The algorithm returns a tensor of shape:
|
||||
$$[B \times N_{out} \times N' \times D]$$
|
||||
where $B$ is the batch_size, $N_{out}$ is the number of output
|
||||
fields, $N'$ the number of points in the mesh, $D$ the dimension
|
||||
of the problem.
|
||||
|
||||
:param input_numb_field: number of fields in the input
|
||||
:type input_numb_field: int
|
||||
:param output_numb_field: number of fields in the output
|
||||
:type output_numb_field: int
|
||||
:param filter_dim: dimension of the filter
|
||||
:type filter_dim: tuple/ list
|
||||
:param stride: stride for the filter
|
||||
:type stride: dict
|
||||
:param model: neural network for inner parametrization,
|
||||
defaults to None
|
||||
:type model: torch.nn.Module, optional
|
||||
:param optimize: flag for performing optimization on the continuous
|
||||
filter, defaults to False. The flag `optimize=True` should be
|
||||
used only when the scatter datapoints are fixed through the
|
||||
training. If torch model is in `.eval()` mode, the flag is
|
||||
automatically set to False always.
|
||||
:type optimize: bool, optional
|
||||
:param no_overlap: flag for performing optimization on the transpose
|
||||
continuous filter, defaults to False. The flag set to `True` should
|
||||
be used only when the filter positions do not overlap for different
|
||||
strides. RuntimeError will raise in case of non-compatible strides.
|
||||
:type no_overlap: bool, optional
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if isinstance(input_numb_field, int):
|
||||
self._input_numb_field = input_numb_field
|
||||
else:
|
||||
raise ValueError('input_numb_field must be int.')
|
||||
|
||||
if isinstance(output_numb_field, int):
|
||||
self._output_numb_field = output_numb_field
|
||||
else:
|
||||
raise ValueError('input_numb_field must be int.')
|
||||
|
||||
if isinstance(filter_dim, (tuple, list)):
|
||||
vect = filter_dim
|
||||
else:
|
||||
raise ValueError('filter_dim must be tuple or list.')
|
||||
vect = torch.tensor(vect)
|
||||
self.register_buffer("_dim", vect, persistent=False)
|
||||
|
||||
if isinstance(stride, dict):
|
||||
self._stride = Stride(stride)
|
||||
else:
|
||||
raise ValueError('stride must be dictionary.')
|
||||
|
||||
self._net = model
|
||||
|
||||
if isinstance(optimize, bool):
|
||||
self._optimize = optimize
|
||||
else:
|
||||
raise ValueError('optimize must be bool.')
|
||||
|
||||
# choosing how to initialize based on optimization
|
||||
if self._optimize:
|
||||
# optimizing decorator ensure the function is called
|
||||
# just once
|
||||
self._choose_initialization = optimizing(
|
||||
self._initialize_convolution)
|
||||
else:
|
||||
self._choose_initialization = self._initialize_convolution
|
||||
|
||||
if not isinstance(no_overlap, bool):
|
||||
raise ValueError('no_overlap must be bool.')
|
||||
|
||||
if no_overlap:
|
||||
raise NotImplementedError
|
||||
self.transpose = self.transpose_no_overlap
|
||||
else:
|
||||
self.transpose = self.transpose_overlap
|
||||
|
||||
@ property
|
||||
def net(self):
|
||||
return self._net
|
||||
|
||||
@ property
|
||||
def stride(self):
|
||||
return self._stride
|
||||
|
||||
@ property
|
||||
def dim(self):
|
||||
return self._dim
|
||||
|
||||
@ property
|
||||
def input_numb_field(self):
|
||||
return self._input_numb_field
|
||||
|
||||
@ property
|
||||
def output_numb_field(self):
|
||||
return self._output_numb_field
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def forward(self, X):
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def transpose_overlap(self, X):
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def transpose_no_overlap(self, X):
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def _initialize_convolution(self, X, type):
|
||||
pass
|
||||
548
pina/model/layers/convolution_2d.py
Normal file
548
pina/model/layers/convolution_2d.py
Normal file
@@ -0,0 +1,548 @@
|
||||
"""Module for Continuous Convolution class"""
|
||||
from .convolution import BaseContinuousConv
|
||||
from .utils_convolution import check_point, map_points_
|
||||
from .integral import Integral
|
||||
from ..feed_forward import FeedForward
|
||||
import torch
|
||||
|
||||
|
||||
class ContinuousConv(BaseContinuousConv):
|
||||
"""
|
||||
Implementation of Continuous Convolutional operator.
|
||||
|
||||
.. seealso::
|
||||
|
||||
**Original reference**: Coscia, D., Meneghetti, L., Demo, N.,
|
||||
Stabile, G., & Rozza, G.. (2022). A Continuous Convolutional Trainable
|
||||
Filter for Modelling Unstructured Data.
|
||||
DOI: `10.48550/arXiv.2210.13416
|
||||
<https://doi.org/10.48550/arXiv.2210.13416>`_.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, input_numb_field, output_numb_field,
|
||||
filter_dim, stride, model=None, optimize=False,
|
||||
no_overlap=False):
|
||||
"""
|
||||
|
||||
:param input_numb_field: Number of fields N_in in the input.
|
||||
:type input_numb_field: int
|
||||
:param output_numb_field: Number of fields N_out in the output.
|
||||
:type output_numb_field: int
|
||||
:param filter_dim: Dimension of the filter.
|
||||
:type filter_dim: tuple/ list
|
||||
:param stride: Stride for the filter.
|
||||
:type stride: dict
|
||||
:param model: Neural network for inner parametrization,
|
||||
defaults to None. If None, pina.FeedForward is used, more
|
||||
on https://mathlab.github.io/PINA/_rst/fnn.html.
|
||||
:type model: torch.nn.Module, optional
|
||||
:param optimize: Flag for performing optimization on the continuous
|
||||
filter, defaults to False. The flag `optimize=True` should be
|
||||
used only when the scatter datapoints are fixed through the
|
||||
training. If torch model is in `.eval()` mode, the flag is
|
||||
automatically set to False always.
|
||||
:type optimize: bool, optional
|
||||
:param no_overlap: Flag for performing optimization on the transpose
|
||||
continuous filter, defaults to False. The flag set to `True` should
|
||||
be used only when the filter positions do not overlap for different
|
||||
strides. RuntimeError will raise in case of non-compatible strides.
|
||||
:type no_overlap: bool, optional
|
||||
|
||||
.. note::
|
||||
Using `optimize=True` the filter can be use either in `forward`
|
||||
or in `transpose` mode, not both. If `optimize=False` the same
|
||||
filter can be used for both `transpose` and `forward` modes.
|
||||
|
||||
.. warning::
|
||||
The algorithm expects input to be in the form: [B x N_in x N x D]
|
||||
where B is the batch_size, N_in is the number of input
|
||||
fields, N the number of points in the mesh, D the dimension
|
||||
of the problem. In particular:
|
||||
|
||||
* D is the number of spatial variables + 1. The last column must
|
||||
contain the field value. For example for 2D problems D=3 and
|
||||
the tensor will be something like `[first coordinate, second
|
||||
coordinate, field value]`.
|
||||
|
||||
* N_in represents the number of vectorial function presented.
|
||||
For example a vectorial function f = [f_1, f_2] will have
|
||||
N_in=2.
|
||||
|
||||
The algorithm returns a tensor of shape: [B x N_out x N x D]
|
||||
where B is the batch_size, N_out is the number of output
|
||||
fields, N' the number of points in the mesh, D the dimension
|
||||
of the problem (coordinates + field value).
|
||||
|
||||
For example, a 2-dimensional vectorial function N_in=2 of
|
||||
3-dimensionalcinput D=3+1=4 with 100 points input mesh and batch
|
||||
size of 8 is represented as a tensor `[8, 2, 100, 4]`, where the
|
||||
columnsc`[:, 0, :, -1]` and `[:, 1, :, -1]` represent the first and
|
||||
second filed value respectively.
|
||||
|
||||
:Example:
|
||||
>>> class MLP(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self. model = torch.nn.Sequential(
|
||||
torch.nn.Linear(2, 8),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(8, 8),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(8, 1))
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
>>> dim = [3, 3]
|
||||
>>> stride = {"domain": [10, 10],
|
||||
"start": [0, 0],
|
||||
"jumps": [3, 3],
|
||||
"direction": [1, 1.]}
|
||||
>>> conv = ContinuousConv2D(1, 2, dim, stride, MLP)
|
||||
>>> conv
|
||||
ContinuousConv2D(
|
||||
(_net): ModuleList(
|
||||
(0): MLP(
|
||||
(model): Sequential(
|
||||
(0): Linear(in_features=2, out_features=8, bias=True)
|
||||
(1): ReLU()
|
||||
(2): Linear(in_features=8, out_features=8, bias=True)
|
||||
(3): ReLU()
|
||||
(4): Linear(in_features=8, out_features=1, bias=True)
|
||||
)
|
||||
)
|
||||
(1): MLP(
|
||||
(model): Sequential(
|
||||
(0): Linear(in_features=2, out_features=8, bias=True)
|
||||
(1): ReLU()
|
||||
(2): Linear(in_features=8, out_features=8, bias=True)
|
||||
(3): ReLU()
|
||||
(4): Linear(in_features=8, out_features=1, bias=True)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
"""
|
||||
|
||||
super().__init__(input_numb_field=input_numb_field,
|
||||
output_numb_field=output_numb_field,
|
||||
filter_dim=filter_dim,
|
||||
stride=stride,
|
||||
model=model,
|
||||
optimize=optimize,
|
||||
no_overlap=no_overlap)
|
||||
|
||||
# integral routine
|
||||
self._integral = Integral('discrete')
|
||||
|
||||
# create the network
|
||||
self._net = self._spawn_networks(model)
|
||||
|
||||
# stride for continuous convolution overridden
|
||||
self._stride = self._stride._stride_discrete
|
||||
|
||||
def _spawn_networks(self, model):
|
||||
"""Private method to create a collection of kernels
|
||||
|
||||
:param model: a torch.nn.Module model in form of Object class
|
||||
:type model: torch.nn.Module
|
||||
:return: list of torch.nn.Module models
|
||||
:rtype: torch.nn.ModuleList
|
||||
|
||||
"""
|
||||
nets = []
|
||||
if self._net is None:
|
||||
for _ in range(self._input_numb_field * self._output_numb_field):
|
||||
tmp = FeedForward(len(self._dim), 1)
|
||||
nets.append(tmp)
|
||||
else:
|
||||
if not isinstance(model, object):
|
||||
raise ValueError("Expected a python class inheriting"
|
||||
" from torch.nn.Module")
|
||||
|
||||
for _ in range(self._input_numb_field * self._output_numb_field):
|
||||
tmp = model()
|
||||
if not isinstance(tmp, torch.nn.Module):
|
||||
raise ValueError("The python class must be inherited from"
|
||||
" torch.nn.Module. See the docstring for"
|
||||
" an example.")
|
||||
nets.append(tmp)
|
||||
|
||||
return torch.nn.ModuleList(nets)
|
||||
|
||||
def _extract_mapped_points(self, batch_idx, index, x):
|
||||
"""Priviate method to extract mapped points in the filter
|
||||
|
||||
:param x: input tensor [channel x N x dim]
|
||||
:type x: torch.tensor
|
||||
:return: mapped points and indeces for each channel
|
||||
:rtype: tuple(torch.tensor, list)
|
||||
|
||||
"""
|
||||
mapped_points = []
|
||||
indeces_channels = []
|
||||
|
||||
for stride_idx, current_stride in enumerate(self._stride):
|
||||
|
||||
# indeces of points falling into filter range
|
||||
indeces = index[stride_idx][batch_idx]
|
||||
|
||||
# how many points for each channel fall into the filter?
|
||||
numb_points_insiede = torch.sum(indeces, dim=-1).tolist()
|
||||
|
||||
# extracting points for each channel
|
||||
# shape: [sum(numb_points_insiede), filter_dim + 1]
|
||||
point_stride = x[indeces]
|
||||
|
||||
# mapping points in filter domain
|
||||
map_points_(point_stride[..., :-1], current_stride)
|
||||
|
||||
# extracting points for each channel
|
||||
point_stride_channel = point_stride.split(numb_points_insiede)
|
||||
|
||||
# appending in list for later use
|
||||
mapped_points.append(point_stride_channel)
|
||||
indeces_channels.append(numb_points_insiede)
|
||||
|
||||
# stacking input for passing to neural net
|
||||
mapping = map(torch.cat, zip(*mapped_points))
|
||||
stacked_input = tuple(mapping)
|
||||
indeces_channels = tuple(zip(*indeces_channels))
|
||||
|
||||
return stacked_input, indeces_channels
|
||||
|
||||
def _find_index(self, X):
|
||||
"""Private method to extract indeces for convolution.
|
||||
|
||||
:param X: input tensor, as in ContinuousConv2D docstring
|
||||
:type X: torch.tensor
|
||||
|
||||
"""
|
||||
# append the index for each stride
|
||||
index = []
|
||||
for _, current_stride in enumerate(self._stride):
|
||||
|
||||
tmp = check_point(X, current_stride, self._dim)
|
||||
index.append(tmp)
|
||||
|
||||
# storing the index
|
||||
self._index = index
|
||||
|
||||
def _make_grid_forward(self, X):
|
||||
"""Private method to create forward convolution grid.
|
||||
|
||||
:param X: input tensor, as in ContinuousConv2D docstring
|
||||
:type X: torch.tensor
|
||||
|
||||
"""
|
||||
# filter dimension + number of points in output grid
|
||||
filter_dim = len(self._dim)
|
||||
number_points = len(self._stride)
|
||||
|
||||
# initialize the grid
|
||||
grid = torch.zeros(size=(X.shape[0],
|
||||
self._output_numb_field,
|
||||
number_points,
|
||||
filter_dim + 1),
|
||||
device=X.device,
|
||||
dtype=X.dtype)
|
||||
grid[..., :-1] = (self._stride + self._dim * 0.5)
|
||||
|
||||
# saving the grid
|
||||
self._grid = grid.detach()
|
||||
|
||||
def _make_grid_transpose(self, X):
|
||||
"""Private method to create transpose convolution grid.
|
||||
|
||||
:param X: input tensor, as in ContinuousConv2D docstring
|
||||
:type X: torch.tensor
|
||||
|
||||
"""
|
||||
# initialize to all zeros
|
||||
tmp = torch.zeros_like(X)
|
||||
tmp[..., :-1] = X[..., :-1]
|
||||
|
||||
# save on tmp
|
||||
self._grid_transpose = tmp
|
||||
|
||||
def _make_grid(self, X, type):
|
||||
"""Private method to create convolution grid.
|
||||
|
||||
:param X: input tensor, as in ContinuousConv2D docstring
|
||||
:type X: torch.tensor
|
||||
:param type: type of convolution, ['forward', 'inverse'] the
|
||||
possibilities
|
||||
:type type: string
|
||||
|
||||
"""
|
||||
# choose the type of convolution
|
||||
if type == 'forward':
|
||||
return self._make_grid_forward(X)
|
||||
elif type == 'inverse':
|
||||
self._make_grid_transpose(X)
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def _initialize_convolution(self, X, type='forward'):
|
||||
"""Private method to intialize the convolution.
|
||||
The convolution is initialized by setting a grid and
|
||||
calculate the index for finding the points inside the
|
||||
filter.
|
||||
|
||||
:param X: input tensor, as in ContinuousConv2D docstring
|
||||
:type X: torch.tensor
|
||||
:param type: type of convolution, ['forward', 'inverse'] the
|
||||
possibilities
|
||||
:type type: string
|
||||
"""
|
||||
|
||||
# variable for the convolution
|
||||
self._make_grid(X, type)
|
||||
|
||||
# calculate the index
|
||||
self._find_index(X)
|
||||
|
||||
def forward(self, X):
|
||||
"""Forward pass in the layer
|
||||
|
||||
:param x: input data (input_numb_field x N x filter_dim)
|
||||
:type x: torch.tensor
|
||||
:return: feed forward convolution (output_numb_field x N x filter_dim)
|
||||
:rtype: torch.tensor
|
||||
"""
|
||||
|
||||
# initialize convolution
|
||||
if self.training: # we choose what to do based on optimization
|
||||
self._choose_initialization(X, type='forward')
|
||||
|
||||
else: # we always initialize on testing
|
||||
self._initialize_convolution(X, 'forward')
|
||||
|
||||
# create convolutional array
|
||||
conv = self._grid.clone().detach()
|
||||
|
||||
# total number of fields
|
||||
tot_dim = self._output_numb_field * self._input_numb_field
|
||||
|
||||
for batch_idx, x in enumerate(X):
|
||||
|
||||
# extract mapped points
|
||||
stacked_input, indeces_channels = self._extract_mapped_points(
|
||||
batch_idx, self._index, x)
|
||||
|
||||
# compute the convolution
|
||||
|
||||
# storing intermidiate results for each channel convolution
|
||||
res_tmp = []
|
||||
# for each field
|
||||
for idx_conv in range(tot_dim):
|
||||
# index for each input field
|
||||
idx = idx_conv % self._input_numb_field
|
||||
# extract input for each channel
|
||||
single_channel_input = stacked_input[idx]
|
||||
# extract filter
|
||||
net = self._net[idx_conv]
|
||||
# calculate filter value
|
||||
staked_output = net(single_channel_input[..., :-1])
|
||||
# perform integral for all strides in one field
|
||||
integral = self._integral(staked_output,
|
||||
single_channel_input[..., -1],
|
||||
indeces_channels[idx])
|
||||
res_tmp.append(integral)
|
||||
|
||||
# stacking integral results
|
||||
res_tmp = torch.stack(res_tmp)
|
||||
|
||||
# sum filters (for each input fields) in groups
|
||||
# for different ouput fields
|
||||
conv[batch_idx, ..., -1] = res_tmp.reshape(self._output_numb_field,
|
||||
self._input_numb_field,
|
||||
-1).sum(1)
|
||||
return conv
|
||||
|
||||
def transpose_no_overlap(self, integrals, X):
|
||||
"""Transpose pass in the layer for no-overlapping filters
|
||||
|
||||
:param integrals: Weights for the transpose convolution. Shape
|
||||
[B x N_in x N]
|
||||
where B is the batch_size, N_in is the number of input
|
||||
fields, N the number of points in the mesh, D the dimension
|
||||
of the problem.
|
||||
:type integral: torch.tensor
|
||||
:param X: Input data. Expect tensor of shape
|
||||
[B x N_in x M x D] where B is the batch_size,
|
||||
N_in is the number of input fields, M the number of points
|
||||
in the mesh, D the dimension of the problem. Note, last column
|
||||
:type X: torch.tensor
|
||||
:return: Feed forward transpose convolution. Tensor of shape
|
||||
[B x N_out x N] where B is the batch_size,
|
||||
N_out is the number of output fields, N the number of points
|
||||
in the mesh, D the dimension of the problem.
|
||||
:rtype: torch.tensor
|
||||
|
||||
.. note::
|
||||
This function is automatically called when `.transpose()`
|
||||
method is used and `no_overlap=True`
|
||||
"""
|
||||
|
||||
# initialize convolution
|
||||
if self.training: # we choose what to do based on optimization
|
||||
self._choose_initialization(X, type='inverse')
|
||||
|
||||
else: # we always initialize on testing
|
||||
self._initialize_convolution(X, 'inverse')
|
||||
|
||||
# initialize grid
|
||||
X = self._grid_transpose.clone().detach()
|
||||
conv_transposed = self._grid_transpose.clone().detach()
|
||||
|
||||
# total number of dim
|
||||
tot_dim = self._input_numb_field * self._output_numb_field
|
||||
|
||||
for batch_idx, x in enumerate(X):
|
||||
|
||||
# extract mapped points
|
||||
stacked_input, indeces_channels = self._extract_mapped_points(
|
||||
batch_idx, self._index, x)
|
||||
|
||||
# compute the transpose convolution
|
||||
|
||||
# total number of fields
|
||||
res_tmp = []
|
||||
|
||||
# for each field
|
||||
for idx_conv in range(tot_dim):
|
||||
# index for each output field
|
||||
idx = idx_conv % self._output_numb_field
|
||||
# index for each input field
|
||||
idx_in = idx_conv % self._input_numb_field
|
||||
# extract input for each field
|
||||
single_channel_input = stacked_input[idx]
|
||||
rep_idx = torch.tensor(indeces_channels[idx])
|
||||
integral = integrals[batch_idx,
|
||||
idx_in, :].repeat_interleave(rep_idx)
|
||||
# extract filter
|
||||
net = self._net[idx_conv]
|
||||
# perform transpose convolution for all strides in one field
|
||||
staked_output = net(single_channel_input[..., :-1]).flatten()
|
||||
integral = staked_output * integral
|
||||
res_tmp.append(integral)
|
||||
|
||||
# stacking integral results and sum
|
||||
# filters (for each input fields) in groups
|
||||
# for different output fields
|
||||
res_tmp = torch.stack(res_tmp).reshape(self._input_numb_field,
|
||||
self._output_numb_field,
|
||||
-1).sum(0)
|
||||
conv_transposed[batch_idx, ..., -1] = res_tmp
|
||||
|
||||
return conv_transposed
|
||||
|
||||
def transpose_overlap(self, integrals, X):
|
||||
"""Transpose pass in the layer for overlapping filters
|
||||
|
||||
:param integrals: Weights for the transpose convolution. Shape
|
||||
[B x N_in x N]
|
||||
where B is the batch_size, N_in is the number of input
|
||||
fields, N the number of points in the mesh, D the dimension
|
||||
of the problem.
|
||||
:type integral: torch.tensor
|
||||
:param X: Input data. Expect tensor of shape
|
||||
[B x N_in x M x D] where B is the batch_size,
|
||||
N_in is the number of input fields, M the number of points
|
||||
in the mesh, D the dimension of the problem. Note, last column
|
||||
:type X: torch.tensor
|
||||
:return: Feed forward transpose convolution. Tensor of shape
|
||||
[B x N_out x N] where B is the batch_size,
|
||||
N_out is the number of output fields, N the number of points
|
||||
in the mesh, D the dimension of the problem.
|
||||
:rtype: torch.tensor
|
||||
|
||||
.. note:: This function is automatically called when `.transpose()`
|
||||
method is used and `no_overlap=False`
|
||||
"""
|
||||
|
||||
# initialize convolution
|
||||
if self.training: # we choose what to do based on optimization
|
||||
self._choose_initialization(X, type='inverse')
|
||||
|
||||
else: # we always initialize on testing
|
||||
self._initialize_convolution(X, 'inverse')
|
||||
|
||||
# initialize grid
|
||||
X = self._grid_transpose.clone().detach()
|
||||
conv_transposed = self._grid_transpose.clone().detach()
|
||||
|
||||
# list to iterate for calculating nn output
|
||||
tmp = [i for i in range(self._output_numb_field)]
|
||||
iterate_conv = [item for item in tmp for _ in range(
|
||||
self._input_numb_field)]
|
||||
|
||||
for batch_idx, x in enumerate(X):
|
||||
|
||||
# accumulator for the convolution on different batches
|
||||
accumulator_batch = torch.zeros(
|
||||
size=(self._grid_transpose.shape[1],
|
||||
self._grid_transpose.shape[2]),
|
||||
requires_grad=True,
|
||||
device=X.device,
|
||||
dtype=X.dtype).clone()
|
||||
|
||||
for stride_idx, current_stride in enumerate(self._stride):
|
||||
# indeces of points falling into filter range
|
||||
indeces = self._index[stride_idx][batch_idx]
|
||||
|
||||
# number of points for each channel
|
||||
numb_pts_channel = tuple(indeces.sum(dim=-1))
|
||||
|
||||
# extracting points for each channel
|
||||
point_stride = x[indeces]
|
||||
|
||||
# if no points to upsample we just skip
|
||||
if point_stride.nelement() == 0:
|
||||
continue
|
||||
|
||||
# mapping points in filter domain
|
||||
map_points_(point_stride[..., :-1], current_stride)
|
||||
|
||||
# input points for kernels
|
||||
# we split for extracting number of points for each channel
|
||||
nn_input_pts = point_stride[..., :-1].split(numb_pts_channel)
|
||||
|
||||
# accumulate partial convolution results for each field
|
||||
res_tmp = []
|
||||
|
||||
# for each channel field compute transpose convolution
|
||||
for idx_conv, idx_channel_out in enumerate(iterate_conv):
|
||||
|
||||
# index for input channels
|
||||
idx_channel_in = idx_conv % self._input_numb_field
|
||||
|
||||
# extract filter
|
||||
net = self._net[idx_conv]
|
||||
|
||||
# calculate filter value
|
||||
staked_output = net(nn_input_pts[idx_channel_out])
|
||||
|
||||
# perform integral for all strides in one field
|
||||
integral = staked_output * integrals[batch_idx,
|
||||
idx_channel_in,
|
||||
stride_idx]
|
||||
# append results
|
||||
res_tmp.append(integral.flatten())
|
||||
|
||||
# computing channel sum
|
||||
channel_sum = []
|
||||
start = 0
|
||||
for _ in range(self._output_numb_field):
|
||||
tmp = res_tmp[start:start + self._input_numb_field]
|
||||
tmp = torch.vstack(tmp).sum(dim=0)
|
||||
channel_sum.append(tmp)
|
||||
start += self._input_numb_field
|
||||
|
||||
# accumulate the results
|
||||
accumulator_batch[indeces] += torch.hstack(channel_sum)
|
||||
|
||||
# save results of accumulation for each batch
|
||||
conv_transposed[batch_idx, ..., -1] = accumulator_batch
|
||||
|
||||
return conv_transposed
|
||||
63
pina/model/layers/integral.py
Normal file
63
pina/model/layers/integral.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import torch
|
||||
|
||||
|
||||
class Integral(object):
|
||||
|
||||
def __init__(self, param):
|
||||
"""Integral class for continous convolution
|
||||
|
||||
:param param: type of continuous convolution
|
||||
:type param: string
|
||||
"""
|
||||
|
||||
if param == 'discrete':
|
||||
self.make_integral = self.integral_param_disc
|
||||
elif param == 'continuous':
|
||||
self.make_integral = self.integral_param_cont
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def __call__(self, *args, **kwds):
|
||||
return self.make_integral(*args, **kwds)
|
||||
|
||||
def _prepend_zero(self, x):
|
||||
"""Create bins for performing integral
|
||||
|
||||
:param x: input tensor
|
||||
:type x: torch.tensor
|
||||
:return: bins for integrals
|
||||
:rtype: torch.tensor
|
||||
"""
|
||||
return torch.cat((torch.zeros(1, dtype=x.dtype, device=x.device), x))
|
||||
|
||||
def integral_param_disc(self, x, y, idx):
|
||||
"""Perform discretize integral
|
||||
with discrete parameters
|
||||
|
||||
:param x: input vector
|
||||
:type x: torch.tensor
|
||||
:param y: input vector
|
||||
:type y: torch.tensor
|
||||
:param idx: indeces for different strides
|
||||
:type idx: list
|
||||
:return: integral
|
||||
:rtype: torch.tensor
|
||||
"""
|
||||
cs_idxes = self._prepend_zero(torch.cumsum(torch.tensor(idx), 0))
|
||||
cs = self._prepend_zero(torch.cumsum(x.flatten() * y.flatten(), 0))
|
||||
return cs[cs_idxes[1:]] - cs[cs_idxes[:-1]]
|
||||
|
||||
def integral_param_cont(self, x, y, idx):
|
||||
"""Perform discretize integral for continuous convolution
|
||||
with continuous parameters
|
||||
|
||||
:param x: input vector
|
||||
:type x: torch.tensor
|
||||
:param y: input vector
|
||||
:type y: torch.tensor
|
||||
:param idx: indeces for different strides
|
||||
:type idx: list
|
||||
:return: integral
|
||||
:rtype: torch.tensor
|
||||
"""
|
||||
raise NotImplementedError
|
||||
82
pina/model/layers/stride.py
Normal file
82
pina/model/layers/stride.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import torch
|
||||
|
||||
|
||||
class Stride(object):
|
||||
|
||||
def __init__(self, dict):
|
||||
"""Stride class for continous convolution
|
||||
|
||||
:param param: type of continuous convolution
|
||||
:type param: string
|
||||
"""
|
||||
|
||||
self._dict_stride = dict
|
||||
self._stride_continuous = None
|
||||
self._stride_discrete = self._create_stride_discrete(dict)
|
||||
|
||||
def _create_stride_discrete(self, my_dict):
|
||||
"""Creating the list for applying the filter
|
||||
|
||||
:param my_dict: Dictionary with the following arguments:
|
||||
domain size, starting position of the filter, jump size
|
||||
for the filter and direction of the filter
|
||||
:type my_dict: dict
|
||||
:raises IndexError: Values in the dict must have all same length
|
||||
:raises ValueError: Domain values must be greater than 0
|
||||
:raises ValueError: Direction must be either equal to 1, -1 or 0
|
||||
:raises IndexError: Direction and jumps must have zero in the same
|
||||
index
|
||||
:return: list of positions for the filter
|
||||
:rtype: list
|
||||
:Example:
|
||||
|
||||
|
||||
>>> stride = {"domain": [4, 4],
|
||||
"start": [-4, 2],
|
||||
"jump": [2, 2],
|
||||
"direction": [1, 1],
|
||||
}
|
||||
>>> create_stride(stride)
|
||||
[[-4.0, 2.0], [-4.0, 4.0], [-2.0, 2.0], [-2.0, 4.0]]
|
||||
"""
|
||||
|
||||
# we must check boundaries of the input as well
|
||||
|
||||
domain, start, jumps, direction = my_dict.values()
|
||||
|
||||
# checking
|
||||
|
||||
if not all([len(s) == len(domain) for s in my_dict.values()]):
|
||||
raise IndexError("values in the dict must have all same length")
|
||||
|
||||
if not all(v >= 0 for v in domain):
|
||||
raise ValueError("domain values must be greater than 0")
|
||||
|
||||
if not all(v == 1 or v == -1 or v == 0 for v in direction):
|
||||
raise ValueError("direction must be either equal to 1, -1 or 0")
|
||||
|
||||
seq_jumps = [i for i, e in enumerate(jumps) if e == 0]
|
||||
seq_direction = [i for i, e in enumerate(direction) if e == 0]
|
||||
|
||||
if seq_direction != seq_jumps:
|
||||
raise IndexError(
|
||||
"direction and jumps must have zero in the same index")
|
||||
|
||||
if seq_jumps:
|
||||
for i in seq_jumps:
|
||||
jumps[i] = domain[i]
|
||||
direction[i] = 1
|
||||
|
||||
# creating the stride grid
|
||||
values_mesh = [torch.arange(0, i, step).float()
|
||||
for i, step in zip(domain, jumps)]
|
||||
|
||||
values_mesh = [single * dim for single,
|
||||
dim in zip(values_mesh, direction)]
|
||||
|
||||
mesh = torch.meshgrid(values_mesh)
|
||||
coordinates_mesh = [x.reshape(-1, 1) for x in mesh]
|
||||
|
||||
stride = torch.cat(coordinates_mesh, dim=1) + torch.tensor(start)
|
||||
|
||||
return stride
|
||||
48
pina/model/layers/utils_convolution.py
Normal file
48
pina/model/layers/utils_convolution.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import torch
|
||||
|
||||
|
||||
def check_point(x, current_stride, dim):
|
||||
max_stride = current_stride + dim
|
||||
indeces = torch.logical_and(x[..., :-1] < max_stride,
|
||||
x[..., :-1] >= current_stride).all(dim=-1)
|
||||
return indeces
|
||||
|
||||
|
||||
def map_points_(x, filter_position):
|
||||
"""Mapping function n dimensional case
|
||||
|
||||
:param x: input data of two dimension
|
||||
:type x: torch.tensor
|
||||
:param filter_position: position of the filter
|
||||
:type dim: list[numeric]
|
||||
:return: data mapped inplace
|
||||
:rtype: torch.tensor
|
||||
"""
|
||||
x.add_(-filter_position)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def optimizing(f):
|
||||
"""Decorator for calling a function just once
|
||||
|
||||
:param f: python function
|
||||
:type f: function
|
||||
"""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
|
||||
if kwargs['type'] == 'forward':
|
||||
if not wrapper.has_run_inverse:
|
||||
wrapper.has_run_inverse = True
|
||||
return f(*args, **kwargs)
|
||||
|
||||
if kwargs['type'] == 'inverse':
|
||||
if not wrapper.has_run:
|
||||
wrapper.has_run = True
|
||||
return f(*args, **kwargs)
|
||||
|
||||
wrapper.has_run_inverse = False
|
||||
wrapper.has_run = False
|
||||
|
||||
return wrapper
|
||||
Reference in New Issue
Block a user