🎨 Format Python code with psf/black

This commit is contained in:
ndem0
2024-02-09 11:25:00 +00:00
committed by Nicola Demo
parent 591aeeb02b
commit cbb43a5392
64 changed files with 1323 additions and 955 deletions

View File

@@ -1,18 +1,22 @@
__all__ = [
'ContinuousConvBlock',
'ResidualBlock',
'EnhancedLinear',
'SpectralConvBlock1D',
'SpectralConvBlock2D',
'SpectralConvBlock3D',
'FourierBlock1D',
'FourierBlock2D',
'FourierBlock3D',
'PODLayer'
"ContinuousConvBlock",
"ResidualBlock",
"EnhancedLinear",
"SpectralConvBlock1D",
"SpectralConvBlock2D",
"SpectralConvBlock3D",
"FourierBlock1D",
"FourierBlock2D",
"FourierBlock3D",
"PODLayer",
]
from .convolution_2d import ContinuousConvBlock
from .residual import ResidualBlock, EnhancedLinear
from .spectral import SpectralConvBlock1D, SpectralConvBlock2D, SpectralConvBlock3D
from .spectral import (
SpectralConvBlock1D,
SpectralConvBlock2D,
SpectralConvBlock3D,
)
from .fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D
from .pod import PODLayer

View File

@@ -1,4 +1,5 @@
"""Module for Base Continuous Convolution class."""
from abc import ABCMeta, abstractmethod
import torch
from .stride import Stride
@@ -10,14 +11,16 @@ class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
Abstract class
"""
def __init__(self,
input_numb_field,
output_numb_field,
filter_dim,
stride,
model=None,
optimize=False,
no_overlap=False):
def __init__(
self,
input_numb_field,
output_numb_field,
filter_dim,
stride,
model=None,
optimize=False,
no_overlap=False,
):
"""
Base Class for Continuous Convolution.
@@ -75,43 +78,44 @@ class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
if isinstance(input_numb_field, int):
self._input_numb_field = input_numb_field
else:
raise ValueError('input_numb_field must be int.')
raise ValueError("input_numb_field must be int.")
if isinstance(output_numb_field, int):
self._output_numb_field = output_numb_field
else:
raise ValueError('input_numb_field must be int.')
raise ValueError("input_numb_field must be int.")
if isinstance(filter_dim, (tuple, list)):
vect = filter_dim
else:
raise ValueError('filter_dim must be tuple or list.')
raise ValueError("filter_dim must be tuple or list.")
vect = torch.tensor(vect)
self.register_buffer("_dim", vect, persistent=False)
if isinstance(stride, dict):
self._stride = Stride(stride)
else:
raise ValueError('stride must be dictionary.')
raise ValueError("stride must be dictionary.")
self._net = model
if isinstance(optimize, bool):
self._optimize = optimize
else:
raise ValueError('optimize must be bool.')
raise ValueError("optimize must be bool.")
# choosing how to initialize based on optimization
if self._optimize:
# optimizing decorator ensure the function is called
# just once
self._choose_initialization = optimizing(
self._initialize_convolution)
self._initialize_convolution
)
else:
self._choose_initialization = self._initialize_convolution
if not isinstance(no_overlap, bool):
raise ValueError('no_overlap must be bool.')
raise ValueError("no_overlap must be bool.")
if no_overlap:
raise NotImplementedError
@@ -125,11 +129,13 @@ class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
super().__init__()
assert isinstance(input_dim, int)
assert isinstance(output_dim, int)
self._model = torch.nn.Sequential(torch.nn.Linear(input_dim, 20),
torch.nn.ReLU(),
torch.nn.Linear(20, 20),
torch.nn.ReLU(),
torch.nn.Linear(20, output_dim))
self._model = torch.nn.Sequential(
torch.nn.Linear(input_dim, 20),
torch.nn.ReLU(),
torch.nn.Linear(20, 20),
torch.nn.ReLU(),
torch.nn.Linear(20, output_dim),
)
def forward(self, x):
return self._model(x)

View File

@@ -1,4 +1,5 @@
"""Module for Continuous Convolution class"""
from .convolution import BaseContinuousConv
from .utils_convolution import check_point, map_points_
from .integral import Integral
@@ -31,14 +32,16 @@ class ContinuousConvBlock(BaseContinuousConv):
"""
def __init__(self,
input_numb_field,
output_numb_field,
filter_dim,
stride,
model=None,
optimize=False,
no_overlap=False):
def __init__(
self,
input_numb_field,
output_numb_field,
filter_dim,
stride,
model=None,
optimize=False,
no_overlap=False,
):
"""
:param input_numb_field: Number of fields :math:`N_{in}` in the input.
:type input_numb_field: int
@@ -112,16 +115,18 @@ class ContinuousConvBlock(BaseContinuousConv):
)
"""
super().__init__(input_numb_field=input_numb_field,
output_numb_field=output_numb_field,
filter_dim=filter_dim,
stride=stride,
model=model,
optimize=optimize,
no_overlap=no_overlap)
super().__init__(
input_numb_field=input_numb_field,
output_numb_field=output_numb_field,
filter_dim=filter_dim,
stride=stride,
model=model,
optimize=optimize,
no_overlap=no_overlap,
)
# integral routine
self._integral = Integral('discrete')
self._integral = Integral("discrete")
# create the network
self._net = self._spawn_networks(model)
@@ -146,15 +151,18 @@ class ContinuousConvBlock(BaseContinuousConv):
nets.append(tmp)
else:
if not isinstance(model, object):
raise ValueError("Expected a python class inheriting"
" from torch.nn.Module")
raise ValueError(
"Expected a python class inheriting" " from torch.nn.Module"
)
for _ in range(self._input_numb_field * self._output_numb_field):
tmp = model()
if not isinstance(tmp, torch.nn.Module):
raise ValueError("The python class must be inherited from"
" torch.nn.Module. See the docstring for"
" an example.")
raise ValueError(
"The python class must be inherited from"
" torch.nn.Module. See the docstring for"
" an example."
)
nets.append(tmp)
return torch.nn.ModuleList(nets)
@@ -232,11 +240,17 @@ class ContinuousConvBlock(BaseContinuousConv):
number_points = len(self._stride)
# initialize the grid
grid = torch.zeros(size=(X.shape[0], self._output_numb_field,
number_points, filter_dim + 1),
device=X.device,
dtype=X.dtype)
grid[..., :-1] = (self._stride + self._dim * 0.5)
grid = torch.zeros(
size=(
X.shape[0],
self._output_numb_field,
number_points,
filter_dim + 1,
),
device=X.device,
dtype=X.dtype,
)
grid[..., :-1] = self._stride + self._dim * 0.5
# saving the grid
self._grid = grid.detach()
@@ -269,14 +283,14 @@ class ContinuousConvBlock(BaseContinuousConv):
"""
# choose the type of convolution
if type == 'forward':
if type == "forward":
return self._make_grid_forward(X)
elif type == 'inverse':
elif type == "inverse":
self._make_grid_transpose(X)
else:
raise TypeError
def _initialize_convolution(self, X, type='forward'):
def _initialize_convolution(self, X, type="forward"):
"""
Private method to intialize the convolution.
The convolution is initialized by setting a grid and
@@ -307,10 +321,10 @@ class ContinuousConvBlock(BaseContinuousConv):
# initialize convolution
if self.training: # we choose what to do based on optimization
self._choose_initialization(X, type='forward')
self._choose_initialization(X, type="forward")
else: # we always initialize on testing
self._initialize_convolution(X, 'forward')
self._initialize_convolution(X, "forward")
# create convolutional array
conv = self._grid.clone().detach()
@@ -322,7 +336,8 @@ class ContinuousConvBlock(BaseContinuousConv):
# extract mapped points
stacked_input, indeces_channels = self._extract_mapped_points(
batch_idx, self._index, x)
batch_idx, self._index, x
)
# compute the convolution
@@ -339,9 +354,11 @@ class ContinuousConvBlock(BaseContinuousConv):
# calculate filter value
staked_output = net(single_channel_input[..., :-1])
# perform integral for all strides in one field
integral = self._integral(staked_output,
single_channel_input[..., -1],
indeces_channels[idx])
integral = self._integral(
staked_output,
single_channel_input[..., -1],
indeces_channels[idx],
)
res_tmp.append(integral)
# stacking integral results
@@ -349,9 +366,9 @@ class ContinuousConvBlock(BaseContinuousConv):
# sum filters (for each input fields) in groups
# for different ouput fields
conv[batch_idx, ...,
-1] = res_tmp.reshape(self._output_numb_field,
self._input_numb_field, -1).sum(1)
conv[batch_idx, ..., -1] = res_tmp.reshape(
self._output_numb_field, self._input_numb_field, -1
).sum(1)
return conv
def transpose_no_overlap(self, integrals, X):
@@ -382,10 +399,10 @@ class ContinuousConvBlock(BaseContinuousConv):
# initialize convolution
if self.training: # we choose what to do based on optimization
self._choose_initialization(X, type='inverse')
self._choose_initialization(X, type="inverse")
else: # we always initialize on testing
self._initialize_convolution(X, 'inverse')
self._initialize_convolution(X, "inverse")
# initialize grid
X = self._grid_transpose.clone().detach()
@@ -398,7 +415,8 @@ class ContinuousConvBlock(BaseContinuousConv):
# extract mapped points
stacked_input, indeces_channels = self._extract_mapped_points(
batch_idx, self._index, x)
batch_idx, self._index, x
)
# compute the transpose convolution
@@ -414,8 +432,9 @@ class ContinuousConvBlock(BaseContinuousConv):
# extract input for each field
single_channel_input = stacked_input[idx]
rep_idx = torch.tensor(indeces_channels[idx])
integral = integrals[batch_idx,
idx_in, :].repeat_interleave(rep_idx)
integral = integrals[batch_idx, idx_in, :].repeat_interleave(
rep_idx
)
# extract filter
net = self._net[idx_conv]
# perform transpose convolution for all strides in one field
@@ -426,9 +445,11 @@ class ContinuousConvBlock(BaseContinuousConv):
# stacking integral results and sum
# filters (for each input fields) in groups
# for different output fields
res_tmp = torch.stack(res_tmp).reshape(self._input_numb_field,
self._output_numb_field,
-1).sum(0)
res_tmp = (
torch.stack(res_tmp)
.reshape(self._input_numb_field, self._output_numb_field, -1)
.sum(0)
)
conv_transposed[batch_idx, ..., -1] = res_tmp
return conv_transposed
@@ -460,10 +481,10 @@ class ContinuousConvBlock(BaseContinuousConv):
# initialize convolution
if self.training: # we choose what to do based on optimization
self._choose_initialization(X, type='inverse')
self._choose_initialization(X, type="inverse")
else: # we always initialize on testing
self._initialize_convolution(X, 'inverse')
self._initialize_convolution(X, "inverse")
# initialize grid
X = self._grid_transpose.clone().detach()
@@ -479,11 +500,14 @@ class ContinuousConvBlock(BaseContinuousConv):
# accumulator for the convolution on different batches
accumulator_batch = torch.zeros(
size=(self._grid_transpose.shape[1],
self._grid_transpose.shape[2]),
size=(
self._grid_transpose.shape[1],
self._grid_transpose.shape[2],
),
requires_grad=True,
device=X.device,
dtype=X.dtype).clone()
dtype=X.dtype,
).clone()
for stride_idx, current_stride in enumerate(self._stride):
# indeces of points falling into filter range
@@ -522,9 +546,10 @@ class ContinuousConvBlock(BaseContinuousConv):
staked_output = net(nn_input_pts[idx_channel_out])
# perform integral for all strides in one field
integral = staked_output * integrals[batch_idx,
idx_channel_in,
stride_idx]
integral = (
staked_output
* integrals[batch_idx, idx_channel_in, stride_idx]
)
# append results
res_tmp.append(integral.flatten())
@@ -532,7 +557,7 @@ class ContinuousConvBlock(BaseContinuousConv):
channel_sum = []
start = 0
for _ in range(self._output_numb_field):
tmp = res_tmp[start:start + self._input_numb_field]
tmp = res_tmp[start : start + self._input_numb_field]
tmp = torch.vstack(tmp).sum(dim=0)
channel_sum.append(tmp)
start += self._input_numb_field

View File

@@ -2,14 +2,18 @@ import torch
import torch.nn as nn
from ...utils import check_consistency
from pina.model.layers import SpectralConvBlock1D, SpectralConvBlock2D, SpectralConvBlock3D
from pina.model.layers import (
SpectralConvBlock1D,
SpectralConvBlock2D,
SpectralConvBlock3D,
)
class FourierBlock1D(nn.Module):
"""
Fourier block implementation for three dimensional
input tensor. The combination of Fourier blocks
make up the Fourier Neural Operator
make up the Fourier Neural Operator
.. seealso::
@@ -21,11 +25,13 @@ class FourierBlock1D(nn.Module):
"""
def __init__(self,
input_numb_fields,
output_numb_fields,
n_modes,
activation=torch.nn.Tanh):
def __init__(
self,
input_numb_fields,
output_numb_fields,
n_modes,
activation=torch.nn.Tanh,
):
super().__init__()
"""
PINA implementation of Fourier block one dimension. The module computes
@@ -51,17 +57,18 @@ class FourierBlock1D(nn.Module):
self._spectral_conv = SpectralConvBlock1D(
input_numb_fields=input_numb_fields,
output_numb_fields=output_numb_fields,
n_modes=n_modes)
n_modes=n_modes,
)
self._activation = activation()
self._linear = nn.Conv1d(input_numb_fields, output_numb_fields, 1)
def forward(self, x):
"""
Forward computation for Fourier Block. It performs a spectral
Forward computation for Fourier Block. It performs a spectral
convolution and a linear transformation of the input and sum the
results.
:param x: The input tensor for fourier block, expect of size
:param x: The input tensor for fourier block, expect of size
``[batch, input_numb_fields, x]``.
:type x: torch.Tensor
:return: The output tensor obtained from the
@@ -75,7 +82,7 @@ class FourierBlock2D(nn.Module):
"""
Fourier block implementation for two dimensional
input tensor. The combination of Fourier blocks
make up the Fourier Neural Operator
make up the Fourier Neural Operator
.. seealso::
@@ -87,18 +94,20 @@ class FourierBlock2D(nn.Module):
"""
def __init__(self,
input_numb_fields,
output_numb_fields,
n_modes,
activation=torch.nn.Tanh):
def __init__(
self,
input_numb_fields,
output_numb_fields,
n_modes,
activation=torch.nn.Tanh,
):
"""
PINA implementation of Fourier block two dimensions. The module computes
the spectral convolution of the input with a linear kernel in the
fourier space, and then it maps the input back to the physical
space. The output is then added to a Linear tranformation of the
input in the physical space. Finally an activation function is
applied to the output.
applied to the output.
The block expects an input of size ``[batch, input_numb_fields, Nx, Ny]``
and returns an output of size ``[batch, output_numb_fields, Nx, Ny]``.
@@ -118,17 +127,18 @@ class FourierBlock2D(nn.Module):
self._spectral_conv = SpectralConvBlock2D(
input_numb_fields=input_numb_fields,
output_numb_fields=output_numb_fields,
n_modes=n_modes)
n_modes=n_modes,
)
self._activation = activation()
self._linear = nn.Conv2d(input_numb_fields, output_numb_fields, 1)
def forward(self, x):
"""
Forward computation for Fourier Block. It performs a spectral
Forward computation for Fourier Block. It performs a spectral
convolution and a linear transformation of the input and sum the
results.
:param x: The input tensor for fourier block, expect of size
:param x: The input tensor for fourier block, expect of size
``[batch, input_numb_fields, x, y]``.
:type x: torch.Tensor
:return: The output tensor obtained from the
@@ -142,7 +152,7 @@ class FourierBlock3D(nn.Module):
"""
Fourier block implementation for three dimensional
input tensor. The combination of Fourier blocks
make up the Fourier Neural Operator
make up the Fourier Neural Operator
.. seealso::
@@ -154,18 +164,20 @@ class FourierBlock3D(nn.Module):
"""
def __init__(self,
input_numb_fields,
output_numb_fields,
n_modes,
activation=torch.nn.Tanh):
def __init__(
self,
input_numb_fields,
output_numb_fields,
n_modes,
activation=torch.nn.Tanh,
):
"""
PINA implementation of Fourier block three dimensions. The module computes
the spectral convolution of the input with a linear kernel in the
fourier space, and then it maps the input back to the physical
space. The output is then added to a Linear tranformation of the
input in the physical space. Finally an activation function is
applied to the output.
applied to the output.
The block expects an input of size ``[batch, input_numb_fields, Nx, Ny, Nz]``
and returns an output of size ``[batch, output_numb_fields, Nx, Ny, Nz]``.
@@ -186,17 +198,18 @@ class FourierBlock3D(nn.Module):
self._spectral_conv = SpectralConvBlock3D(
input_numb_fields=input_numb_fields,
output_numb_fields=output_numb_fields,
n_modes=n_modes)
n_modes=n_modes,
)
self._activation = activation()
self._linear = nn.Conv3d(input_numb_fields, output_numb_fields, 1)
def forward(self, x):
"""
Forward computation for Fourier Block. It performs a spectral
Forward computation for Fourier Block. It performs a spectral
convolution and a linear transformation of the input and sum the
results.
:param x: The input tensor for fourier block, expect of size
:param x: The input tensor for fourier block, expect of size
``[batch, input_numb_fields, x, y, z]``.
:type x: torch.Tensor
:return: The output tensor obtained from the

View File

@@ -10,9 +10,9 @@ class Integral(object):
:type param: string
"""
if param == 'discrete':
if param == "discrete":
self.make_integral = self.integral_param_disc
elif param == 'continuous':
elif param == "continuous":
self.make_integral = self.integral_param_cont
else:
raise TypeError

View File

@@ -1,4 +1,5 @@
"""Module for Base Continuous Convolution class."""
from abc import ABCMeta, abstractmethod
import torch
from .stride import Stride
@@ -38,17 +39,17 @@ class PODLayer(torch.nn.Module):
:rtype: int
"""
return self._rank
@rank.setter
def rank(self, value):
if value < 1 or not isinstance(value, int):
raise ValueError('The rank must be positive integer')
raise ValueError("The rank must be positive integer")
self._rank = value
@property
def basis(self):
"""
"""
The POD basis. It is a matrix whose columns are the first `self.rank` POD modes.
:rtype: torch.Tensor
@@ -56,7 +57,7 @@ class PODLayer(torch.nn.Module):
if self._basis is None:
return None
return self._basis[:self.rank]
return self._basis[: self.rank]
@property
def scaler(self):
@@ -67,10 +68,12 @@ class PODLayer(torch.nn.Module):
:rtype: dict
"""
if self._scaler is None:
return
return
return {'mean': self._scaler['mean'][:self.rank],
'std': self._scaler['std'][:self.rank]}
return {
"mean": self._scaler["mean"][: self.rank],
"std": self._scaler["std"][: self.rank],
}
@property
def scale_coefficients(self):
@@ -105,8 +108,9 @@ class PODLayer(torch.nn.Module):
:param torch.Tensor coeffs: The coefficients to be scaled.
"""
self._scaler = {
'std': torch.std(coeffs, dim=1),
'mean': torch.mean(coeffs, dim=1)}
"std": torch.std(coeffs, dim=1),
"mean": torch.mean(coeffs, dim=1),
}
def _fit_pod(self, X):
"""
@@ -114,7 +118,7 @@ class PODLayer(torch.nn.Module):
:param torch.Tensor X: The tensor to be reduced.
"""
if X.device.type == 'mps': # svd_lowrank not arailable for mps
if X.device.type == "mps": # svd_lowrank not arailable for mps
self._basis = torch.svd(X.T)[0].T
else:
self._basis = torch.svd_lowrank(X.T, q=X.shape[0])[0].T
@@ -142,7 +146,8 @@ class PODLayer(torch.nn.Module):
"""
if self._basis is None:
raise RuntimeError(
'The POD layer needs to be fitted before being used.')
"The POD layer needs to be fitted before being used."
)
coeff = torch.matmul(self.basis, X.T)
if coeff.ndim == 1:
@@ -150,28 +155,29 @@ class PODLayer(torch.nn.Module):
coeff = coeff.T
if self.__scale_coefficients:
coeff = (coeff - self.scaler['mean']) / self.scaler['std']
coeff = (coeff - self.scaler["mean"]) / self.scaler["std"]
return coeff
def expand(self, coeff):
"""
"""
Expand the given coefficients to the original space. The POD layer needs
to be fitted before being used.
:param torch.Tensor coeff: The coefficients to be expanded.
:return: The expanded tensor.
:rtype: torch.Tensor
"""
if self._basis is None:
raise RuntimeError(
'The POD layer needs to be trained before being used.')
"The POD layer needs to be trained before being used."
)
if self.__scale_coefficients:
coeff = coeff * self.scaler['std'] + self.scaler['mean']
coeff = coeff * self.scaler["std"] + self.scaler["mean"]
predicted = torch.matmul(self.basis.T, coeff.T).T
if predicted.ndim == 1:
predicted = predicted.unsqueeze(0)
return predicted
return predicted

View File

@@ -16,18 +16,20 @@ class ResidualBlock(nn.Module):
"""
def __init__(self,
input_dim,
output_dim,
hidden_dim,
spectral_norm=False,
activation=torch.nn.ReLU()):
def __init__(
self,
input_dim,
output_dim,
hidden_dim,
spectral_norm=False,
activation=torch.nn.ReLU(),
):
"""
Initializes the ResidualBlock module.
:param int input_dim: Dimension of the input to pass to the
feedforward linear layer.
:param int output_dim: Dimension of the output from the
:param int output_dim: Dimension of the output from the
residual layer.
:param int hidden_dim: Hidden dimension for mapping the input
(first block).
@@ -82,6 +84,7 @@ class ResidualBlock(nn.Module):
import torch
import torch.nn as nn
class EnhancedLinear(torch.nn.Module):
"""
A wrapper class for enhancing a linear layer with activation and/or dropout.
@@ -132,8 +135,9 @@ class EnhancedLinear(torch.nn.Module):
self._model = torch.nn.Sequential(layer, self._drop(dropout))
elif (dropout is not None) and (activation is not None):
self._model = torch.nn.Sequential(layer, activation,
self._drop(dropout))
self._model = torch.nn.Sequential(
layer, activation, self._drop(dropout)
)
def forward(self, x):
"""

View File

@@ -37,18 +37,23 @@ class SpectralConvBlock1D(nn.Module):
self._output_channels = output_numb_fields
# scaling factor
scale = (1. / (self._input_channels * self._output_channels))
self._weights = nn.Parameter(scale * torch.rand(self._input_channels,
self._output_channels,
self._modes,
dtype=torch.cfloat))
scale = 1.0 / (self._input_channels * self._output_channels)
self._weights = nn.Parameter(
scale
* torch.rand(
self._input_channels,
self._output_channels,
self._modes,
dtype=torch.cfloat,
)
)
def _compute_mult1d(self, input, weights):
"""
Compute the matrix multiplication of the input
with the linear kernel weights.
:param input: The input tensor, expect of size
:param input: The input tensor, expect of size
``[batch, input_numb_fields, x]``.
:type input: torch.Tensor
:param weights: The kernel weights, expect of
@@ -64,7 +69,7 @@ class SpectralConvBlock1D(nn.Module):
"""
Forward computation for Spectral Convolution.
:param x: The input tensor, expect of size
:param x: The input tensor, expect of size
``[batch, input_numb_fields, x]``.
:type x: torch.Tensor
:return: The output tensor obtained from the
@@ -77,13 +82,16 @@ class SpectralConvBlock1D(nn.Module):
x_ft = torch.fft.rfft(x)
# Multiply relevant Fourier modes
out_ft = torch.zeros(batch_size,
self._output_channels,
x.size(-1) // 2 + 1,
device=x.device,
dtype=torch.cfloat)
out_ft[:, :, :self._modes] = self._compute_mult1d(
x_ft[:, :, :self._modes], self._weights)
out_ft = torch.zeros(
batch_size,
self._output_channels,
x.size(-1) // 2 + 1,
device=x.device,
dtype=torch.cfloat,
)
out_ft[:, :, : self._modes] = self._compute_mult1d(
x_ft[:, :, : self._modes], self._weights
)
# Return to physical space
return torch.fft.irfft(out_ft, n=x.size(-1))
@@ -119,17 +127,19 @@ class SpectralConvBlock2D(nn.Module):
if isinstance(n_modes, (tuple, list)):
if len(n_modes) != 2:
raise ValueError(
'Expected n_modes to be a list or tuple of len two, '
'with each entry corresponding to the number of modes '
'for each dimension ')
"Expected n_modes to be a list or tuple of len two, "
"with each entry corresponding to the number of modes "
"for each dimension "
)
elif isinstance(n_modes, int):
n_modes = [n_modes] * 2
else:
raise ValueError(
'Expected n_modes to be a list or tuple of len two, '
'with each entry corresponding to the number of modes '
'for each dimension; or an int value representing the '
'number of modes for all dimensions')
"Expected n_modes to be a list or tuple of len two, "
"with each entry corresponding to the number of modes "
"for each dimension; or an int value representing the "
"number of modes for all dimensions"
)
# assign variables
self._modes = n_modes
@@ -137,24 +147,34 @@ class SpectralConvBlock2D(nn.Module):
self._output_channels = output_numb_fields
# scaling factor
scale = (1. / (self._input_channels * self._output_channels))
self._weights1 = nn.Parameter(scale * torch.rand(self._input_channels,
self._output_channels,
self._modes[0],
self._modes[1],
dtype=torch.cfloat))
self._weights2 = nn.Parameter(scale * torch.rand(self._input_channels,
self._output_channels,
self._modes[0],
self._modes[1],
dtype=torch.cfloat))
scale = 1.0 / (self._input_channels * self._output_channels)
self._weights1 = nn.Parameter(
scale
* torch.rand(
self._input_channels,
self._output_channels,
self._modes[0],
self._modes[1],
dtype=torch.cfloat,
)
)
self._weights2 = nn.Parameter(
scale
* torch.rand(
self._input_channels,
self._output_channels,
self._modes[0],
self._modes[1],
dtype=torch.cfloat,
)
)
def _compute_mult2d(self, input, weights):
"""
Compute the matrix multiplication of the input
with the linear kernel weights.
:param input: The input tensor, expect of size
:param input: The input tensor, expect of size
``[batch, input_numb_fields, x, y]``.
:type input: torch.Tensor
:param weights: The kernel weights, expect of
@@ -170,7 +190,7 @@ class SpectralConvBlock2D(nn.Module):
"""
Forward computation for Spectral Convolution.
:param x: The input tensor, expect of size
:param x: The input tensor, expect of size
``[batch, input_numb_fields, x, y]``.
:type x: torch.Tensor
:return: The output tensor obtained from the
@@ -184,16 +204,22 @@ class SpectralConvBlock2D(nn.Module):
x_ft = torch.fft.rfft2(x)
# Multiply relevant Fourier modes
out_ft = torch.zeros(batch_size,
self._output_channels,
x.size(-2),
x.size(-1) // 2 + 1,
device=x.device,
dtype=torch.cfloat)
out_ft[:, :, :self._modes[0], :self._modes[1]] = self._compute_mult2d(
x_ft[:, :, :self._modes[0], :self._modes[1]], self._weights1)
out_ft[:, :, -self._modes[0]:, :self._modes[1]:] = self._compute_mult2d(
x_ft[:, :, -self._modes[0]:, :self._modes[1]], self._weights2)
out_ft = torch.zeros(
batch_size,
self._output_channels,
x.size(-2),
x.size(-1) // 2 + 1,
device=x.device,
dtype=torch.cfloat,
)
out_ft[:, :, : self._modes[0], : self._modes[1]] = self._compute_mult2d(
x_ft[:, :, : self._modes[0], : self._modes[1]], self._weights1
)
out_ft[:, :, -self._modes[0] :, : self._modes[1] :] = (
self._compute_mult2d(
x_ft[:, :, -self._modes[0] :, : self._modes[1]], self._weights2
)
)
# Return to physical space
return torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1)))
@@ -230,17 +256,19 @@ class SpectralConvBlock3D(nn.Module):
if isinstance(n_modes, (tuple, list)):
if len(n_modes) != 3:
raise ValueError(
'Expected n_modes to be a list or tuple of len three, '
'with each entry corresponding to the number of modes '
'for each dimension ')
"Expected n_modes to be a list or tuple of len three, "
"with each entry corresponding to the number of modes "
"for each dimension "
)
elif isinstance(n_modes, int):
n_modes = [n_modes] * 3
else:
raise ValueError(
'Expected n_modes to be a list or tuple of len three, '
'with each entry corresponding to the number of modes '
'for each dimension; or an int value representing the '
'number of modes for all dimensions')
"Expected n_modes to be a list or tuple of len three, "
"with each entry corresponding to the number of modes "
"for each dimension; or an int value representing the "
"number of modes for all dimensions"
)
# assign variables
self._modes = n_modes
@@ -248,38 +276,58 @@ class SpectralConvBlock3D(nn.Module):
self._output_channels = output_numb_fields
# scaling factor
scale = (1. / (self._input_channels * self._output_channels))
self._weights1 = nn.Parameter(scale * torch.rand(self._input_channels,
self._output_channels,
self._modes[0],
self._modes[1],
self._modes[2],
dtype=torch.cfloat))
self._weights2 = nn.Parameter(scale * torch.rand(self._input_channels,
self._output_channels,
self._modes[0],
self._modes[1],
self._modes[2],
dtype=torch.cfloat))
self._weights3 = nn.Parameter(scale * torch.rand(self._input_channels,
self._output_channels,
self._modes[0],
self._modes[1],
self._modes[2],
dtype=torch.cfloat))
self._weights4 = nn.Parameter(scale * torch.rand(self._input_channels,
self._output_channels,
self._modes[0],
self._modes[1],
self._modes[2],
dtype=torch.cfloat))
scale = 1.0 / (self._input_channels * self._output_channels)
self._weights1 = nn.Parameter(
scale
* torch.rand(
self._input_channels,
self._output_channels,
self._modes[0],
self._modes[1],
self._modes[2],
dtype=torch.cfloat,
)
)
self._weights2 = nn.Parameter(
scale
* torch.rand(
self._input_channels,
self._output_channels,
self._modes[0],
self._modes[1],
self._modes[2],
dtype=torch.cfloat,
)
)
self._weights3 = nn.Parameter(
scale
* torch.rand(
self._input_channels,
self._output_channels,
self._modes[0],
self._modes[1],
self._modes[2],
dtype=torch.cfloat,
)
)
self._weights4 = nn.Parameter(
scale
* torch.rand(
self._input_channels,
self._output_channels,
self._modes[0],
self._modes[1],
self._modes[2],
dtype=torch.cfloat,
)
)
def _compute_mult3d(self, input, weights):
"""
Compute the matrix multiplication of the input
with the linear kernel weights.
:param input: The input tensor, expect of size
:param input: The input tensor, expect of size
``[batch, input_numb_fields, x, y, z]``.
:type input: torch.Tensor
:param weights: The kernel weights, expect of
@@ -295,7 +343,7 @@ class SpectralConvBlock3D(nn.Module):
"""
Forward computation for Spectral Convolution.
:param x: The input tensor, expect of size
:param x: The input tensor, expect of size
``[batch, input_numb_fields, x, y, z]``.
:type x: torch.Tensor
:return: The output tensor obtained from the
@@ -309,13 +357,15 @@ class SpectralConvBlock3D(nn.Module):
x_ft = torch.fft.rfftn(x, dim=[-3, -2, -1])
# Multiply relevant Fourier modes
out_ft = torch.zeros(batch_size,
self._output_channels,
x.size(-3),
x.size(-2),
x.size(-1) // 2 + 1,
device=x.device,
dtype=torch.cfloat)
out_ft = torch.zeros(
batch_size,
self._output_channels,
x.size(-3),
x.size(-2),
x.size(-1) // 2 + 1,
device=x.device,
dtype=torch.cfloat,
)
slice0 = (
slice(None),

View File

@@ -60,7 +60,8 @@ class Stride(object):
if seq_direction != seq_jumps:
raise IndexError(
"direction and jumps must have zero in the same index")
"direction and jumps must have zero in the same index"
)
if seq_jumps:
for i in seq_jumps:

View File

@@ -3,8 +3,9 @@ import torch
def check_point(x, current_stride, dim):
max_stride = current_stride + dim
indeces = torch.logical_and(x[..., :-1] < max_stride, x[..., :-1]
>= current_stride).all(dim=-1)
indeces = torch.logical_and(
x[..., :-1] < max_stride, x[..., :-1] >= current_stride
).all(dim=-1)
return indeces
@@ -32,12 +33,12 @@ def optimizing(f):
def wrapper(*args, **kwargs):
if kwargs['type'] == 'forward':
if kwargs["type"] == "forward":
if not wrapper.has_run_inverse:
wrapper.has_run_inverse = True
return f(*args, **kwargs)
if kwargs['type'] == 'inverse':
if kwargs["type"] == "inverse":
if not wrapper.has_run:
wrapper.has_run = True
return f(*args, **kwargs)