🎨 Format Python code with psf/black
This commit is contained in:
@@ -37,18 +37,23 @@ class SpectralConvBlock1D(nn.Module):
|
||||
self._output_channels = output_numb_fields
|
||||
|
||||
# scaling factor
|
||||
scale = (1. / (self._input_channels * self._output_channels))
|
||||
self._weights = nn.Parameter(scale * torch.rand(self._input_channels,
|
||||
self._output_channels,
|
||||
self._modes,
|
||||
dtype=torch.cfloat))
|
||||
scale = 1.0 / (self._input_channels * self._output_channels)
|
||||
self._weights = nn.Parameter(
|
||||
scale
|
||||
* torch.rand(
|
||||
self._input_channels,
|
||||
self._output_channels,
|
||||
self._modes,
|
||||
dtype=torch.cfloat,
|
||||
)
|
||||
)
|
||||
|
||||
def _compute_mult1d(self, input, weights):
|
||||
"""
|
||||
Compute the matrix multiplication of the input
|
||||
with the linear kernel weights.
|
||||
|
||||
:param input: The input tensor, expect of size
|
||||
:param input: The input tensor, expect of size
|
||||
``[batch, input_numb_fields, x]``.
|
||||
:type input: torch.Tensor
|
||||
:param weights: The kernel weights, expect of
|
||||
@@ -64,7 +69,7 @@ class SpectralConvBlock1D(nn.Module):
|
||||
"""
|
||||
Forward computation for Spectral Convolution.
|
||||
|
||||
:param x: The input tensor, expect of size
|
||||
:param x: The input tensor, expect of size
|
||||
``[batch, input_numb_fields, x]``.
|
||||
:type x: torch.Tensor
|
||||
:return: The output tensor obtained from the
|
||||
@@ -77,13 +82,16 @@ class SpectralConvBlock1D(nn.Module):
|
||||
x_ft = torch.fft.rfft(x)
|
||||
|
||||
# Multiply relevant Fourier modes
|
||||
out_ft = torch.zeros(batch_size,
|
||||
self._output_channels,
|
||||
x.size(-1) // 2 + 1,
|
||||
device=x.device,
|
||||
dtype=torch.cfloat)
|
||||
out_ft[:, :, :self._modes] = self._compute_mult1d(
|
||||
x_ft[:, :, :self._modes], self._weights)
|
||||
out_ft = torch.zeros(
|
||||
batch_size,
|
||||
self._output_channels,
|
||||
x.size(-1) // 2 + 1,
|
||||
device=x.device,
|
||||
dtype=torch.cfloat,
|
||||
)
|
||||
out_ft[:, :, : self._modes] = self._compute_mult1d(
|
||||
x_ft[:, :, : self._modes], self._weights
|
||||
)
|
||||
|
||||
# Return to physical space
|
||||
return torch.fft.irfft(out_ft, n=x.size(-1))
|
||||
@@ -119,17 +127,19 @@ class SpectralConvBlock2D(nn.Module):
|
||||
if isinstance(n_modes, (tuple, list)):
|
||||
if len(n_modes) != 2:
|
||||
raise ValueError(
|
||||
'Expected n_modes to be a list or tuple of len two, '
|
||||
'with each entry corresponding to the number of modes '
|
||||
'for each dimension ')
|
||||
"Expected n_modes to be a list or tuple of len two, "
|
||||
"with each entry corresponding to the number of modes "
|
||||
"for each dimension "
|
||||
)
|
||||
elif isinstance(n_modes, int):
|
||||
n_modes = [n_modes] * 2
|
||||
else:
|
||||
raise ValueError(
|
||||
'Expected n_modes to be a list or tuple of len two, '
|
||||
'with each entry corresponding to the number of modes '
|
||||
'for each dimension; or an int value representing the '
|
||||
'number of modes for all dimensions')
|
||||
"Expected n_modes to be a list or tuple of len two, "
|
||||
"with each entry corresponding to the number of modes "
|
||||
"for each dimension; or an int value representing the "
|
||||
"number of modes for all dimensions"
|
||||
)
|
||||
|
||||
# assign variables
|
||||
self._modes = n_modes
|
||||
@@ -137,24 +147,34 @@ class SpectralConvBlock2D(nn.Module):
|
||||
self._output_channels = output_numb_fields
|
||||
|
||||
# scaling factor
|
||||
scale = (1. / (self._input_channels * self._output_channels))
|
||||
self._weights1 = nn.Parameter(scale * torch.rand(self._input_channels,
|
||||
self._output_channels,
|
||||
self._modes[0],
|
||||
self._modes[1],
|
||||
dtype=torch.cfloat))
|
||||
self._weights2 = nn.Parameter(scale * torch.rand(self._input_channels,
|
||||
self._output_channels,
|
||||
self._modes[0],
|
||||
self._modes[1],
|
||||
dtype=torch.cfloat))
|
||||
scale = 1.0 / (self._input_channels * self._output_channels)
|
||||
self._weights1 = nn.Parameter(
|
||||
scale
|
||||
* torch.rand(
|
||||
self._input_channels,
|
||||
self._output_channels,
|
||||
self._modes[0],
|
||||
self._modes[1],
|
||||
dtype=torch.cfloat,
|
||||
)
|
||||
)
|
||||
self._weights2 = nn.Parameter(
|
||||
scale
|
||||
* torch.rand(
|
||||
self._input_channels,
|
||||
self._output_channels,
|
||||
self._modes[0],
|
||||
self._modes[1],
|
||||
dtype=torch.cfloat,
|
||||
)
|
||||
)
|
||||
|
||||
def _compute_mult2d(self, input, weights):
|
||||
"""
|
||||
Compute the matrix multiplication of the input
|
||||
with the linear kernel weights.
|
||||
|
||||
:param input: The input tensor, expect of size
|
||||
:param input: The input tensor, expect of size
|
||||
``[batch, input_numb_fields, x, y]``.
|
||||
:type input: torch.Tensor
|
||||
:param weights: The kernel weights, expect of
|
||||
@@ -170,7 +190,7 @@ class SpectralConvBlock2D(nn.Module):
|
||||
"""
|
||||
Forward computation for Spectral Convolution.
|
||||
|
||||
:param x: The input tensor, expect of size
|
||||
:param x: The input tensor, expect of size
|
||||
``[batch, input_numb_fields, x, y]``.
|
||||
:type x: torch.Tensor
|
||||
:return: The output tensor obtained from the
|
||||
@@ -184,16 +204,22 @@ class SpectralConvBlock2D(nn.Module):
|
||||
x_ft = torch.fft.rfft2(x)
|
||||
|
||||
# Multiply relevant Fourier modes
|
||||
out_ft = torch.zeros(batch_size,
|
||||
self._output_channels,
|
||||
x.size(-2),
|
||||
x.size(-1) // 2 + 1,
|
||||
device=x.device,
|
||||
dtype=torch.cfloat)
|
||||
out_ft[:, :, :self._modes[0], :self._modes[1]] = self._compute_mult2d(
|
||||
x_ft[:, :, :self._modes[0], :self._modes[1]], self._weights1)
|
||||
out_ft[:, :, -self._modes[0]:, :self._modes[1]:] = self._compute_mult2d(
|
||||
x_ft[:, :, -self._modes[0]:, :self._modes[1]], self._weights2)
|
||||
out_ft = torch.zeros(
|
||||
batch_size,
|
||||
self._output_channels,
|
||||
x.size(-2),
|
||||
x.size(-1) // 2 + 1,
|
||||
device=x.device,
|
||||
dtype=torch.cfloat,
|
||||
)
|
||||
out_ft[:, :, : self._modes[0], : self._modes[1]] = self._compute_mult2d(
|
||||
x_ft[:, :, : self._modes[0], : self._modes[1]], self._weights1
|
||||
)
|
||||
out_ft[:, :, -self._modes[0] :, : self._modes[1] :] = (
|
||||
self._compute_mult2d(
|
||||
x_ft[:, :, -self._modes[0] :, : self._modes[1]], self._weights2
|
||||
)
|
||||
)
|
||||
|
||||
# Return to physical space
|
||||
return torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1)))
|
||||
@@ -230,17 +256,19 @@ class SpectralConvBlock3D(nn.Module):
|
||||
if isinstance(n_modes, (tuple, list)):
|
||||
if len(n_modes) != 3:
|
||||
raise ValueError(
|
||||
'Expected n_modes to be a list or tuple of len three, '
|
||||
'with each entry corresponding to the number of modes '
|
||||
'for each dimension ')
|
||||
"Expected n_modes to be a list or tuple of len three, "
|
||||
"with each entry corresponding to the number of modes "
|
||||
"for each dimension "
|
||||
)
|
||||
elif isinstance(n_modes, int):
|
||||
n_modes = [n_modes] * 3
|
||||
else:
|
||||
raise ValueError(
|
||||
'Expected n_modes to be a list or tuple of len three, '
|
||||
'with each entry corresponding to the number of modes '
|
||||
'for each dimension; or an int value representing the '
|
||||
'number of modes for all dimensions')
|
||||
"Expected n_modes to be a list or tuple of len three, "
|
||||
"with each entry corresponding to the number of modes "
|
||||
"for each dimension; or an int value representing the "
|
||||
"number of modes for all dimensions"
|
||||
)
|
||||
|
||||
# assign variables
|
||||
self._modes = n_modes
|
||||
@@ -248,38 +276,58 @@ class SpectralConvBlock3D(nn.Module):
|
||||
self._output_channels = output_numb_fields
|
||||
|
||||
# scaling factor
|
||||
scale = (1. / (self._input_channels * self._output_channels))
|
||||
self._weights1 = nn.Parameter(scale * torch.rand(self._input_channels,
|
||||
self._output_channels,
|
||||
self._modes[0],
|
||||
self._modes[1],
|
||||
self._modes[2],
|
||||
dtype=torch.cfloat))
|
||||
self._weights2 = nn.Parameter(scale * torch.rand(self._input_channels,
|
||||
self._output_channels,
|
||||
self._modes[0],
|
||||
self._modes[1],
|
||||
self._modes[2],
|
||||
dtype=torch.cfloat))
|
||||
self._weights3 = nn.Parameter(scale * torch.rand(self._input_channels,
|
||||
self._output_channels,
|
||||
self._modes[0],
|
||||
self._modes[1],
|
||||
self._modes[2],
|
||||
dtype=torch.cfloat))
|
||||
self._weights4 = nn.Parameter(scale * torch.rand(self._input_channels,
|
||||
self._output_channels,
|
||||
self._modes[0],
|
||||
self._modes[1],
|
||||
self._modes[2],
|
||||
dtype=torch.cfloat))
|
||||
scale = 1.0 / (self._input_channels * self._output_channels)
|
||||
self._weights1 = nn.Parameter(
|
||||
scale
|
||||
* torch.rand(
|
||||
self._input_channels,
|
||||
self._output_channels,
|
||||
self._modes[0],
|
||||
self._modes[1],
|
||||
self._modes[2],
|
||||
dtype=torch.cfloat,
|
||||
)
|
||||
)
|
||||
self._weights2 = nn.Parameter(
|
||||
scale
|
||||
* torch.rand(
|
||||
self._input_channels,
|
||||
self._output_channels,
|
||||
self._modes[0],
|
||||
self._modes[1],
|
||||
self._modes[2],
|
||||
dtype=torch.cfloat,
|
||||
)
|
||||
)
|
||||
self._weights3 = nn.Parameter(
|
||||
scale
|
||||
* torch.rand(
|
||||
self._input_channels,
|
||||
self._output_channels,
|
||||
self._modes[0],
|
||||
self._modes[1],
|
||||
self._modes[2],
|
||||
dtype=torch.cfloat,
|
||||
)
|
||||
)
|
||||
self._weights4 = nn.Parameter(
|
||||
scale
|
||||
* torch.rand(
|
||||
self._input_channels,
|
||||
self._output_channels,
|
||||
self._modes[0],
|
||||
self._modes[1],
|
||||
self._modes[2],
|
||||
dtype=torch.cfloat,
|
||||
)
|
||||
)
|
||||
|
||||
def _compute_mult3d(self, input, weights):
|
||||
"""
|
||||
Compute the matrix multiplication of the input
|
||||
with the linear kernel weights.
|
||||
|
||||
:param input: The input tensor, expect of size
|
||||
:param input: The input tensor, expect of size
|
||||
``[batch, input_numb_fields, x, y, z]``.
|
||||
:type input: torch.Tensor
|
||||
:param weights: The kernel weights, expect of
|
||||
@@ -295,7 +343,7 @@ class SpectralConvBlock3D(nn.Module):
|
||||
"""
|
||||
Forward computation for Spectral Convolution.
|
||||
|
||||
:param x: The input tensor, expect of size
|
||||
:param x: The input tensor, expect of size
|
||||
``[batch, input_numb_fields, x, y, z]``.
|
||||
:type x: torch.Tensor
|
||||
:return: The output tensor obtained from the
|
||||
@@ -309,13 +357,15 @@ class SpectralConvBlock3D(nn.Module):
|
||||
x_ft = torch.fft.rfftn(x, dim=[-3, -2, -1])
|
||||
|
||||
# Multiply relevant Fourier modes
|
||||
out_ft = torch.zeros(batch_size,
|
||||
self._output_channels,
|
||||
x.size(-3),
|
||||
x.size(-2),
|
||||
x.size(-1) // 2 + 1,
|
||||
device=x.device,
|
||||
dtype=torch.cfloat)
|
||||
out_ft = torch.zeros(
|
||||
batch_size,
|
||||
self._output_channels,
|
||||
x.size(-3),
|
||||
x.size(-2),
|
||||
x.size(-1) // 2 + 1,
|
||||
device=x.device,
|
||||
dtype=torch.cfloat,
|
||||
)
|
||||
|
||||
slice0 = (
|
||||
slice(None),
|
||||
|
||||
Reference in New Issue
Block a user