🎨 Format Python code with psf/black

This commit is contained in:
ndem0
2024-02-09 11:25:00 +00:00
committed by Nicola Demo
parent 591aeeb02b
commit cbb43a5392
64 changed files with 1323 additions and 955 deletions

View File

@@ -1,4 +1,5 @@
"""Module for Continuous Convolution class"""
from .convolution import BaseContinuousConv
from .utils_convolution import check_point, map_points_
from .integral import Integral
@@ -31,14 +32,16 @@ class ContinuousConvBlock(BaseContinuousConv):
"""
def __init__(self,
input_numb_field,
output_numb_field,
filter_dim,
stride,
model=None,
optimize=False,
no_overlap=False):
def __init__(
self,
input_numb_field,
output_numb_field,
filter_dim,
stride,
model=None,
optimize=False,
no_overlap=False,
):
"""
:param input_numb_field: Number of fields :math:`N_{in}` in the input.
:type input_numb_field: int
@@ -112,16 +115,18 @@ class ContinuousConvBlock(BaseContinuousConv):
)
"""
super().__init__(input_numb_field=input_numb_field,
output_numb_field=output_numb_field,
filter_dim=filter_dim,
stride=stride,
model=model,
optimize=optimize,
no_overlap=no_overlap)
super().__init__(
input_numb_field=input_numb_field,
output_numb_field=output_numb_field,
filter_dim=filter_dim,
stride=stride,
model=model,
optimize=optimize,
no_overlap=no_overlap,
)
# integral routine
self._integral = Integral('discrete')
self._integral = Integral("discrete")
# create the network
self._net = self._spawn_networks(model)
@@ -146,15 +151,18 @@ class ContinuousConvBlock(BaseContinuousConv):
nets.append(tmp)
else:
if not isinstance(model, object):
raise ValueError("Expected a python class inheriting"
" from torch.nn.Module")
raise ValueError(
"Expected a python class inheriting" " from torch.nn.Module"
)
for _ in range(self._input_numb_field * self._output_numb_field):
tmp = model()
if not isinstance(tmp, torch.nn.Module):
raise ValueError("The python class must be inherited from"
" torch.nn.Module. See the docstring for"
" an example.")
raise ValueError(
"The python class must be inherited from"
" torch.nn.Module. See the docstring for"
" an example."
)
nets.append(tmp)
return torch.nn.ModuleList(nets)
@@ -232,11 +240,17 @@ class ContinuousConvBlock(BaseContinuousConv):
number_points = len(self._stride)
# initialize the grid
grid = torch.zeros(size=(X.shape[0], self._output_numb_field,
number_points, filter_dim + 1),
device=X.device,
dtype=X.dtype)
grid[..., :-1] = (self._stride + self._dim * 0.5)
grid = torch.zeros(
size=(
X.shape[0],
self._output_numb_field,
number_points,
filter_dim + 1,
),
device=X.device,
dtype=X.dtype,
)
grid[..., :-1] = self._stride + self._dim * 0.5
# saving the grid
self._grid = grid.detach()
@@ -269,14 +283,14 @@ class ContinuousConvBlock(BaseContinuousConv):
"""
# choose the type of convolution
if type == 'forward':
if type == "forward":
return self._make_grid_forward(X)
elif type == 'inverse':
elif type == "inverse":
self._make_grid_transpose(X)
else:
raise TypeError
def _initialize_convolution(self, X, type='forward'):
def _initialize_convolution(self, X, type="forward"):
"""
Private method to intialize the convolution.
The convolution is initialized by setting a grid and
@@ -307,10 +321,10 @@ class ContinuousConvBlock(BaseContinuousConv):
# initialize convolution
if self.training: # we choose what to do based on optimization
self._choose_initialization(X, type='forward')
self._choose_initialization(X, type="forward")
else: # we always initialize on testing
self._initialize_convolution(X, 'forward')
self._initialize_convolution(X, "forward")
# create convolutional array
conv = self._grid.clone().detach()
@@ -322,7 +336,8 @@ class ContinuousConvBlock(BaseContinuousConv):
# extract mapped points
stacked_input, indeces_channels = self._extract_mapped_points(
batch_idx, self._index, x)
batch_idx, self._index, x
)
# compute the convolution
@@ -339,9 +354,11 @@ class ContinuousConvBlock(BaseContinuousConv):
# calculate filter value
staked_output = net(single_channel_input[..., :-1])
# perform integral for all strides in one field
integral = self._integral(staked_output,
single_channel_input[..., -1],
indeces_channels[idx])
integral = self._integral(
staked_output,
single_channel_input[..., -1],
indeces_channels[idx],
)
res_tmp.append(integral)
# stacking integral results
@@ -349,9 +366,9 @@ class ContinuousConvBlock(BaseContinuousConv):
# sum filters (for each input fields) in groups
# for different ouput fields
conv[batch_idx, ...,
-1] = res_tmp.reshape(self._output_numb_field,
self._input_numb_field, -1).sum(1)
conv[batch_idx, ..., -1] = res_tmp.reshape(
self._output_numb_field, self._input_numb_field, -1
).sum(1)
return conv
def transpose_no_overlap(self, integrals, X):
@@ -382,10 +399,10 @@ class ContinuousConvBlock(BaseContinuousConv):
# initialize convolution
if self.training: # we choose what to do based on optimization
self._choose_initialization(X, type='inverse')
self._choose_initialization(X, type="inverse")
else: # we always initialize on testing
self._initialize_convolution(X, 'inverse')
self._initialize_convolution(X, "inverse")
# initialize grid
X = self._grid_transpose.clone().detach()
@@ -398,7 +415,8 @@ class ContinuousConvBlock(BaseContinuousConv):
# extract mapped points
stacked_input, indeces_channels = self._extract_mapped_points(
batch_idx, self._index, x)
batch_idx, self._index, x
)
# compute the transpose convolution
@@ -414,8 +432,9 @@ class ContinuousConvBlock(BaseContinuousConv):
# extract input for each field
single_channel_input = stacked_input[idx]
rep_idx = torch.tensor(indeces_channels[idx])
integral = integrals[batch_idx,
idx_in, :].repeat_interleave(rep_idx)
integral = integrals[batch_idx, idx_in, :].repeat_interleave(
rep_idx
)
# extract filter
net = self._net[idx_conv]
# perform transpose convolution for all strides in one field
@@ -426,9 +445,11 @@ class ContinuousConvBlock(BaseContinuousConv):
# stacking integral results and sum
# filters (for each input fields) in groups
# for different output fields
res_tmp = torch.stack(res_tmp).reshape(self._input_numb_field,
self._output_numb_field,
-1).sum(0)
res_tmp = (
torch.stack(res_tmp)
.reshape(self._input_numb_field, self._output_numb_field, -1)
.sum(0)
)
conv_transposed[batch_idx, ..., -1] = res_tmp
return conv_transposed
@@ -460,10 +481,10 @@ class ContinuousConvBlock(BaseContinuousConv):
# initialize convolution
if self.training: # we choose what to do based on optimization
self._choose_initialization(X, type='inverse')
self._choose_initialization(X, type="inverse")
else: # we always initialize on testing
self._initialize_convolution(X, 'inverse')
self._initialize_convolution(X, "inverse")
# initialize grid
X = self._grid_transpose.clone().detach()
@@ -479,11 +500,14 @@ class ContinuousConvBlock(BaseContinuousConv):
# accumulator for the convolution on different batches
accumulator_batch = torch.zeros(
size=(self._grid_transpose.shape[1],
self._grid_transpose.shape[2]),
size=(
self._grid_transpose.shape[1],
self._grid_transpose.shape[2],
),
requires_grad=True,
device=X.device,
dtype=X.dtype).clone()
dtype=X.dtype,
).clone()
for stride_idx, current_stride in enumerate(self._stride):
# indeces of points falling into filter range
@@ -522,9 +546,10 @@ class ContinuousConvBlock(BaseContinuousConv):
staked_output = net(nn_input_pts[idx_channel_out])
# perform integral for all strides in one field
integral = staked_output * integrals[batch_idx,
idx_channel_in,
stride_idx]
integral = (
staked_output
* integrals[batch_idx, idx_channel_in, stride_idx]
)
# append results
res_tmp.append(integral.flatten())
@@ -532,7 +557,7 @@ class ContinuousConvBlock(BaseContinuousConv):
channel_sum = []
start = 0
for _ in range(self._output_numb_field):
tmp = res_tmp[start:start + self._input_numb_field]
tmp = res_tmp[start : start + self._input_numb_field]
tmp = torch.vstack(tmp).sum(dim=0)
channel_sum.append(tmp)
start += self._input_numb_field