🎨 Format Python code with psf/black
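The diff below is the output of an automated pass with psf/black. For readers unfamiliar with the tool, here is a minimal sketch of what such a pass does; the input snippet is illustrative, not taken from this commit:

    import black  # pip install black

    # Illustrative pre-black code in the aligned hanging-indent style
    # that this commit removes.
    src = "def f(a,\n      b,\n      c=None):\n    return a\n"

    # black.format_str applies the same rules as the `black` CLI:
    # double quotes, wrapped call arguments get one line each plus a
    # trailing comma, and short signatures are collapsed onto one line.
    print(black.format_str(src, mode=black.FileMode()))
    # -> def f(a, b, c=None):
    #        return a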
@@ -1,4 +1,5 @@
 """Module for Continuous Convolution class"""
+
 from .convolution import BaseContinuousConv
 from .utils_convolution import check_point, map_points_
 from .integral import Integral
@@ -31,14 +32,16 @@ class ContinuousConvBlock(BaseContinuousConv):
 
     """
 
-    def __init__(self,
-                 input_numb_field,
-                 output_numb_field,
-                 filter_dim,
-                 stride,
-                 model=None,
-                 optimize=False,
-                 no_overlap=False):
+    def __init__(
+        self,
+        input_numb_field,
+        output_numb_field,
+        filter_dim,
+        stride,
+        model=None,
+        optimize=False,
+        no_overlap=False,
+    ):
         """
         :param input_numb_field: Number of fields :math:`N_{in}` in the input.
         :type input_numb_field: int
@@ -112,16 +115,18 @@ class ContinuousConvBlock(BaseContinuousConv):
             )
         """
 
-        super().__init__(input_numb_field=input_numb_field,
-                         output_numb_field=output_numb_field,
-                         filter_dim=filter_dim,
-                         stride=stride,
-                         model=model,
-                         optimize=optimize,
-                         no_overlap=no_overlap)
+        super().__init__(
+            input_numb_field=input_numb_field,
+            output_numb_field=output_numb_field,
+            filter_dim=filter_dim,
+            stride=stride,
+            model=model,
+            optimize=optimize,
+            no_overlap=no_overlap,
+        )
 
         # integral routine
-        self._integral = Integral('discrete')
+        self._integral = Integral("discrete")
 
         # create the network
         self._net = self._spawn_networks(model)
@@ -146,15 +151,18 @@ class ContinuousConvBlock(BaseContinuousConv):
                 nets.append(tmp)
         else:
             if not isinstance(model, object):
-                raise ValueError("Expected a python class inheriting"
-                                 " from torch.nn.Module")
+                raise ValueError(
+                    "Expected a python class inheriting" " from torch.nn.Module"
+                )
 
             for _ in range(self._input_numb_field * self._output_numb_field):
                 tmp = model()
                 if not isinstance(tmp, torch.nn.Module):
-                    raise ValueError("The python class must be inherited from"
-                                     " torch.nn.Module. See the docstring for"
-                                     " an example.")
+                    raise ValueError(
+                        "The python class must be inherited from"
+                        " torch.nn.Module. See the docstring for"
+                        " an example."
+                    )
                 nets.append(tmp)
 
         return torch.nn.ModuleList(nets)
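As an aside, the `_spawn_networks` logic being reformatted in the hunk above expects `model` to be a class rather than an instance: it instantiates one copy per input-field/output-field pair and checks that each instance is a `torch.nn.Module`. A minimal self-contained sketch of that pattern, with a hypothetical `Filter` class and illustrative field counts:

    import torch

    class Filter(torch.nn.Module):
        # Hypothetical filter network; the real class is user-supplied.
        def __init__(self):
            super().__init__()
            self.layers = torch.nn.Sequential(
                torch.nn.Linear(2, 8), torch.nn.Tanh(), torch.nn.Linear(8, 1)
            )

        def forward(self, x):
            return self.layers(x)

    def spawn_networks(model, input_numb_field, output_numb_field):
        # One independent filter per (input field, output field) pair,
        # mirroring the loop in the hunk above.
        nets = []
        for _ in range(input_numb_field * output_numb_field):
            tmp = model()  # calling the class gives each filter its own weights
            if not isinstance(tmp, torch.nn.Module):
                raise ValueError("The class must inherit from torch.nn.Module.")
            nets.append(tmp)
        return torch.nn.ModuleList(nets)

    nets = spawn_networks(Filter, input_numb_field=2, output_numb_field=3)
    assert len(nets) == 6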
@@ -232,11 +240,17 @@ class ContinuousConvBlock(BaseContinuousConv):
         number_points = len(self._stride)
 
         # initialize the grid
-        grid = torch.zeros(size=(X.shape[0], self._output_numb_field,
-                                 number_points, filter_dim + 1),
-                           device=X.device,
-                           dtype=X.dtype)
-        grid[..., :-1] = (self._stride + self._dim * 0.5)
+        grid = torch.zeros(
+            size=(
+                X.shape[0],
+                self._output_numb_field,
+                number_points,
+                filter_dim + 1,
+            ),
+            device=X.device,
+            dtype=X.dtype,
+        )
+        grid[..., :-1] = self._stride + self._dim * 0.5
 
         # saving the grid
         self._grid = grid.detach()
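The grid tensor built in this hunk has shape (batch, output fields, stride positions, filter_dim + 1): the first filter_dim slots of the last axis hold point coordinates, and the final slot is reserved for the convolution value. A small sketch with made-up sizes and made-up stride/dim tensors (the real ones come from the layer's configuration):

    import torch

    batch, out_fields, number_points, filter_dim = 4, 2, 10, 2  # illustrative

    grid = torch.zeros(size=(batch, out_fields, number_points, filter_dim + 1))

    stride = torch.rand(number_points, filter_dim)  # hypothetical stride centers
    dim = torch.ones(filter_dim)                    # hypothetical filter extent

    # Coordinates fill every slot except the last; the (10, 2) tensor
    # broadcasts over the batch and field axes of the (4, 2, 10, 3) grid.
    grid[..., :-1] = stride + dim * 0.5
    assert grid.shape == (batch, out_fields, number_points, filter_dim + 1)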
@@ -269,14 +283,14 @@ class ContinuousConvBlock(BaseContinuousConv):
 
         """
         # choose the type of convolution
-        if type == 'forward':
+        if type == "forward":
             return self._make_grid_forward(X)
-        elif type == 'inverse':
+        elif type == "inverse":
             self._make_grid_transpose(X)
         else:
             raise TypeError
 
-    def _initialize_convolution(self, X, type='forward'):
+    def _initialize_convolution(self, X, type="forward"):
         """
         Private method to initialize the convolution.
         The convolution is initialized by setting a grid and
@@ -307,10 +321,10 @@ class ContinuousConvBlock(BaseContinuousConv):
 
         # initialize convolution
         if self.training:  # we choose what to do based on optimization
-            self._choose_initialization(X, type='forward')
+            self._choose_initialization(X, type="forward")
 
         else:  # we always initialize on testing
-            self._initialize_convolution(X, 'forward')
+            self._initialize_convolution(X, "forward")
 
         # create convolutional array
         conv = self._grid.clone().detach()
@@ -322,7 +336,8 @@ class ContinuousConvBlock(BaseContinuousConv):
 
             # extract mapped points
             stacked_input, indeces_channels = self._extract_mapped_points(
-                batch_idx, self._index, x)
+                batch_idx, self._index, x
+            )
 
             # compute the convolution
 
@@ -339,9 +354,11 @@ class ContinuousConvBlock(BaseContinuousConv):
                 # calculate filter value
                 staked_output = net(single_channel_input[..., :-1])
                 # perform integral for all strides in one field
-                integral = self._integral(staked_output,
-                                          single_channel_input[..., -1],
-                                          indeces_channels[idx])
+                integral = self._integral(
+                    staked_output,
+                    single_channel_input[..., -1],
+                    indeces_channels[idx],
+                )
                 res_tmp.append(integral)
 
             # stacking integral results
@@ -349,9 +366,9 @@ class ContinuousConvBlock(BaseContinuousConv):
 
             # sum filters (for each input field) in groups
             # for different output fields
-            conv[batch_idx, ...,
-                 -1] = res_tmp.reshape(self._output_numb_field,
-                                       self._input_numb_field, -1).sum(1)
+            conv[batch_idx, ..., -1] = res_tmp.reshape(
+                self._output_numb_field, self._input_numb_field, -1
+            ).sum(1)
         return conv
 
     def transpose_no_overlap(self, integrals, X):
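The `reshape(...).sum(1)` being rewrapped above is a grouped reduction: the stacked per-pair integrals are laid out with the input-field index varying fastest, so reshaping to (output fields, input fields, strides) and summing over axis 1 collapses each output field's contributions. A toy check of that ordering, with all sizes made up:

    import torch

    out_fields, in_fields, n_strides = 2, 3, 5  # illustrative

    # One row per (output field, input field) pair, input index fastest.
    res_tmp = torch.rand(out_fields * in_fields, n_strides)

    # Grouped sum over input fields, as in the hunk above:
    summed = res_tmp.reshape(out_fields, in_fields, -1).sum(1)
    assert summed.shape == (out_fields, n_strides)

    # Equivalent explicit loop, for clarity:
    manual = torch.stack(
        [res_tmp[i * in_fields : (i + 1) * in_fields].sum(0) for i in range(out_fields)]
    )
    assert torch.allclose(summed, manual)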
@@ -382,10 +399,10 @@ class ContinuousConvBlock(BaseContinuousConv):
 
         # initialize convolution
         if self.training:  # we choose what to do based on optimization
-            self._choose_initialization(X, type='inverse')
+            self._choose_initialization(X, type="inverse")
 
         else:  # we always initialize on testing
-            self._initialize_convolution(X, 'inverse')
+            self._initialize_convolution(X, "inverse")
 
         # initialize grid
         X = self._grid_transpose.clone().detach()
@@ -398,7 +415,8 @@ class ContinuousConvBlock(BaseContinuousConv):
 
             # extract mapped points
             stacked_input, indeces_channels = self._extract_mapped_points(
-                batch_idx, self._index, x)
+                batch_idx, self._index, x
+            )
 
             # compute the transpose convolution
 
@@ -414,8 +432,9 @@ class ContinuousConvBlock(BaseContinuousConv):
                 # extract input for each field
                 single_channel_input = stacked_input[idx]
                 rep_idx = torch.tensor(indeces_channels[idx])
-                integral = integrals[batch_idx,
-                                     idx_in, :].repeat_interleave(rep_idx)
+                integral = integrals[batch_idx, idx_in, :].repeat_interleave(
+                    rep_idx
+                )
                 # extract filter
                 net = self._net[idx_conv]
                 # perform transpose convolution for all strides in one field
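`repeat_interleave` with a tensor argument, used just above, repeats each element a per-element number of times: one integral value per stride is expanded to match however many points fell inside that stride's filter. A tiny illustration with made-up values:

    import torch

    integral = torch.tensor([10.0, 20.0, 30.0])  # one value per stride position
    rep_idx = torch.tensor([2, 1, 3])            # points captured by each stride

    # Each integral value is repeated once per captured point:
    expanded = integral.repeat_interleave(rep_idx)
    print(expanded)  # tensor([10., 10., 20., 30., 30., 30.])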
@@ -426,9 +445,11 @@ class ContinuousConvBlock(BaseContinuousConv):
             # stacking integral results and sum
             # filters (for each input field) in groups
             # for different output fields
-            res_tmp = torch.stack(res_tmp).reshape(self._input_numb_field,
-                                                   self._output_numb_field,
-                                                   -1).sum(0)
+            res_tmp = (
+                torch.stack(res_tmp)
+                .reshape(self._input_numb_field, self._output_numb_field, -1)
+                .sum(0)
+            )
             conv_transposed[batch_idx, ..., -1] = res_tmp
 
         return conv_transposed
@@ -460,10 +481,10 @@ class ContinuousConvBlock(BaseContinuousConv):
 
         # initialize convolution
         if self.training:  # we choose what to do based on optimization
-            self._choose_initialization(X, type='inverse')
+            self._choose_initialization(X, type="inverse")
 
         else:  # we always initialize on testing
-            self._initialize_convolution(X, 'inverse')
+            self._initialize_convolution(X, "inverse")
 
         # initialize grid
         X = self._grid_transpose.clone().detach()
@@ -479,11 +500,14 @@ class ContinuousConvBlock(BaseContinuousConv):
 
         # accumulator for the convolution on different batches
         accumulator_batch = torch.zeros(
-            size=(self._grid_transpose.shape[1],
-                  self._grid_transpose.shape[2]),
+            size=(
+                self._grid_transpose.shape[1],
+                self._grid_transpose.shape[2],
+            ),
             requires_grad=True,
             device=X.device,
-            dtype=X.dtype).clone()
+            dtype=X.dtype,
+        ).clone()
 
         for stride_idx, current_stride in enumerate(self._stride):
             # indices of points falling into filter range
@@ -522,9 +546,10 @@ class ContinuousConvBlock(BaseContinuousConv):
                     staked_output = net(nn_input_pts[idx_channel_out])
 
                     # perform integral for all strides in one field
-                    integral = staked_output * integrals[batch_idx,
-                                                         idx_channel_in,
-                                                         stride_idx]
+                    integral = (
+                        staked_output
+                        * integrals[batch_idx, idx_channel_in, stride_idx]
+                    )
                     # append results
                     res_tmp.append(integral.flatten())
 
@@ -532,7 +557,7 @@ class ContinuousConvBlock(BaseContinuousConv):
             channel_sum = []
             start = 0
             for _ in range(self._output_numb_field):
-                tmp = res_tmp[start:start + self._input_numb_field]
+                tmp = res_tmp[start : start + self._input_numb_field]
                 tmp = torch.vstack(tmp).sum(dim=0)
                 channel_sum.append(tmp)
                 start += self._input_numb_field