🎨 Format Python code with psf/black
This commit is contained in:
@@ -1,18 +1,22 @@
|
||||
__all__ = [
|
||||
'ContinuousConvBlock',
|
||||
'ResidualBlock',
|
||||
'EnhancedLinear',
|
||||
'SpectralConvBlock1D',
|
||||
'SpectralConvBlock2D',
|
||||
'SpectralConvBlock3D',
|
||||
'FourierBlock1D',
|
||||
'FourierBlock2D',
|
||||
'FourierBlock3D',
|
||||
'PODLayer'
|
||||
"ContinuousConvBlock",
|
||||
"ResidualBlock",
|
||||
"EnhancedLinear",
|
||||
"SpectralConvBlock1D",
|
||||
"SpectralConvBlock2D",
|
||||
"SpectralConvBlock3D",
|
||||
"FourierBlock1D",
|
||||
"FourierBlock2D",
|
||||
"FourierBlock3D",
|
||||
"PODLayer",
|
||||
]
|
||||
|
||||
from .convolution_2d import ContinuousConvBlock
|
||||
from .residual import ResidualBlock, EnhancedLinear
|
||||
from .spectral import SpectralConvBlock1D, SpectralConvBlock2D, SpectralConvBlock3D
|
||||
from .spectral import (
|
||||
SpectralConvBlock1D,
|
||||
SpectralConvBlock2D,
|
||||
SpectralConvBlock3D,
|
||||
)
|
||||
from .fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D
|
||||
from .pod import PODLayer
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Module for Base Continuous Convolution class."""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
import torch
|
||||
from .stride import Stride
|
||||
@@ -10,14 +11,16 @@ class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
|
||||
Abstract class
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
input_numb_field,
|
||||
output_numb_field,
|
||||
filter_dim,
|
||||
stride,
|
||||
model=None,
|
||||
optimize=False,
|
||||
no_overlap=False):
|
||||
def __init__(
|
||||
self,
|
||||
input_numb_field,
|
||||
output_numb_field,
|
||||
filter_dim,
|
||||
stride,
|
||||
model=None,
|
||||
optimize=False,
|
||||
no_overlap=False,
|
||||
):
|
||||
"""
|
||||
Base Class for Continuous Convolution.
|
||||
|
||||
@@ -75,43 +78,44 @@ class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
|
||||
if isinstance(input_numb_field, int):
|
||||
self._input_numb_field = input_numb_field
|
||||
else:
|
||||
raise ValueError('input_numb_field must be int.')
|
||||
raise ValueError("input_numb_field must be int.")
|
||||
|
||||
if isinstance(output_numb_field, int):
|
||||
self._output_numb_field = output_numb_field
|
||||
else:
|
||||
raise ValueError('input_numb_field must be int.')
|
||||
raise ValueError("input_numb_field must be int.")
|
||||
|
||||
if isinstance(filter_dim, (tuple, list)):
|
||||
vect = filter_dim
|
||||
else:
|
||||
raise ValueError('filter_dim must be tuple or list.')
|
||||
raise ValueError("filter_dim must be tuple or list.")
|
||||
vect = torch.tensor(vect)
|
||||
self.register_buffer("_dim", vect, persistent=False)
|
||||
|
||||
if isinstance(stride, dict):
|
||||
self._stride = Stride(stride)
|
||||
else:
|
||||
raise ValueError('stride must be dictionary.')
|
||||
raise ValueError("stride must be dictionary.")
|
||||
|
||||
self._net = model
|
||||
|
||||
if isinstance(optimize, bool):
|
||||
self._optimize = optimize
|
||||
else:
|
||||
raise ValueError('optimize must be bool.')
|
||||
raise ValueError("optimize must be bool.")
|
||||
|
||||
# choosing how to initialize based on optimization
|
||||
if self._optimize:
|
||||
# optimizing decorator ensure the function is called
|
||||
# just once
|
||||
self._choose_initialization = optimizing(
|
||||
self._initialize_convolution)
|
||||
self._initialize_convolution
|
||||
)
|
||||
else:
|
||||
self._choose_initialization = self._initialize_convolution
|
||||
|
||||
if not isinstance(no_overlap, bool):
|
||||
raise ValueError('no_overlap must be bool.')
|
||||
raise ValueError("no_overlap must be bool.")
|
||||
|
||||
if no_overlap:
|
||||
raise NotImplementedError
|
||||
@@ -125,11 +129,13 @@ class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
|
||||
super().__init__()
|
||||
assert isinstance(input_dim, int)
|
||||
assert isinstance(output_dim, int)
|
||||
self._model = torch.nn.Sequential(torch.nn.Linear(input_dim, 20),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(20, 20),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(20, output_dim))
|
||||
self._model = torch.nn.Sequential(
|
||||
torch.nn.Linear(input_dim, 20),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(20, 20),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(20, output_dim),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self._model(x)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Module for Continuous Convolution class"""
|
||||
|
||||
from .convolution import BaseContinuousConv
|
||||
from .utils_convolution import check_point, map_points_
|
||||
from .integral import Integral
|
||||
@@ -31,14 +32,16 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
input_numb_field,
|
||||
output_numb_field,
|
||||
filter_dim,
|
||||
stride,
|
||||
model=None,
|
||||
optimize=False,
|
||||
no_overlap=False):
|
||||
def __init__(
|
||||
self,
|
||||
input_numb_field,
|
||||
output_numb_field,
|
||||
filter_dim,
|
||||
stride,
|
||||
model=None,
|
||||
optimize=False,
|
||||
no_overlap=False,
|
||||
):
|
||||
"""
|
||||
:param input_numb_field: Number of fields :math:`N_{in}` in the input.
|
||||
:type input_numb_field: int
|
||||
@@ -112,16 +115,18 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
)
|
||||
"""
|
||||
|
||||
super().__init__(input_numb_field=input_numb_field,
|
||||
output_numb_field=output_numb_field,
|
||||
filter_dim=filter_dim,
|
||||
stride=stride,
|
||||
model=model,
|
||||
optimize=optimize,
|
||||
no_overlap=no_overlap)
|
||||
super().__init__(
|
||||
input_numb_field=input_numb_field,
|
||||
output_numb_field=output_numb_field,
|
||||
filter_dim=filter_dim,
|
||||
stride=stride,
|
||||
model=model,
|
||||
optimize=optimize,
|
||||
no_overlap=no_overlap,
|
||||
)
|
||||
|
||||
# integral routine
|
||||
self._integral = Integral('discrete')
|
||||
self._integral = Integral("discrete")
|
||||
|
||||
# create the network
|
||||
self._net = self._spawn_networks(model)
|
||||
@@ -146,15 +151,18 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
nets.append(tmp)
|
||||
else:
|
||||
if not isinstance(model, object):
|
||||
raise ValueError("Expected a python class inheriting"
|
||||
" from torch.nn.Module")
|
||||
raise ValueError(
|
||||
"Expected a python class inheriting" " from torch.nn.Module"
|
||||
)
|
||||
|
||||
for _ in range(self._input_numb_field * self._output_numb_field):
|
||||
tmp = model()
|
||||
if not isinstance(tmp, torch.nn.Module):
|
||||
raise ValueError("The python class must be inherited from"
|
||||
" torch.nn.Module. See the docstring for"
|
||||
" an example.")
|
||||
raise ValueError(
|
||||
"The python class must be inherited from"
|
||||
" torch.nn.Module. See the docstring for"
|
||||
" an example."
|
||||
)
|
||||
nets.append(tmp)
|
||||
|
||||
return torch.nn.ModuleList(nets)
|
||||
@@ -232,11 +240,17 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
number_points = len(self._stride)
|
||||
|
||||
# initialize the grid
|
||||
grid = torch.zeros(size=(X.shape[0], self._output_numb_field,
|
||||
number_points, filter_dim + 1),
|
||||
device=X.device,
|
||||
dtype=X.dtype)
|
||||
grid[..., :-1] = (self._stride + self._dim * 0.5)
|
||||
grid = torch.zeros(
|
||||
size=(
|
||||
X.shape[0],
|
||||
self._output_numb_field,
|
||||
number_points,
|
||||
filter_dim + 1,
|
||||
),
|
||||
device=X.device,
|
||||
dtype=X.dtype,
|
||||
)
|
||||
grid[..., :-1] = self._stride + self._dim * 0.5
|
||||
|
||||
# saving the grid
|
||||
self._grid = grid.detach()
|
||||
@@ -269,14 +283,14 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
|
||||
"""
|
||||
# choose the type of convolution
|
||||
if type == 'forward':
|
||||
if type == "forward":
|
||||
return self._make_grid_forward(X)
|
||||
elif type == 'inverse':
|
||||
elif type == "inverse":
|
||||
self._make_grid_transpose(X)
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def _initialize_convolution(self, X, type='forward'):
|
||||
def _initialize_convolution(self, X, type="forward"):
|
||||
"""
|
||||
Private method to intialize the convolution.
|
||||
The convolution is initialized by setting a grid and
|
||||
@@ -307,10 +321,10 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
|
||||
# initialize convolution
|
||||
if self.training: # we choose what to do based on optimization
|
||||
self._choose_initialization(X, type='forward')
|
||||
self._choose_initialization(X, type="forward")
|
||||
|
||||
else: # we always initialize on testing
|
||||
self._initialize_convolution(X, 'forward')
|
||||
self._initialize_convolution(X, "forward")
|
||||
|
||||
# create convolutional array
|
||||
conv = self._grid.clone().detach()
|
||||
@@ -322,7 +336,8 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
|
||||
# extract mapped points
|
||||
stacked_input, indeces_channels = self._extract_mapped_points(
|
||||
batch_idx, self._index, x)
|
||||
batch_idx, self._index, x
|
||||
)
|
||||
|
||||
# compute the convolution
|
||||
|
||||
@@ -339,9 +354,11 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
# calculate filter value
|
||||
staked_output = net(single_channel_input[..., :-1])
|
||||
# perform integral for all strides in one field
|
||||
integral = self._integral(staked_output,
|
||||
single_channel_input[..., -1],
|
||||
indeces_channels[idx])
|
||||
integral = self._integral(
|
||||
staked_output,
|
||||
single_channel_input[..., -1],
|
||||
indeces_channels[idx],
|
||||
)
|
||||
res_tmp.append(integral)
|
||||
|
||||
# stacking integral results
|
||||
@@ -349,9 +366,9 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
|
||||
# sum filters (for each input fields) in groups
|
||||
# for different ouput fields
|
||||
conv[batch_idx, ...,
|
||||
-1] = res_tmp.reshape(self._output_numb_field,
|
||||
self._input_numb_field, -1).sum(1)
|
||||
conv[batch_idx, ..., -1] = res_tmp.reshape(
|
||||
self._output_numb_field, self._input_numb_field, -1
|
||||
).sum(1)
|
||||
return conv
|
||||
|
||||
def transpose_no_overlap(self, integrals, X):
|
||||
@@ -382,10 +399,10 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
|
||||
# initialize convolution
|
||||
if self.training: # we choose what to do based on optimization
|
||||
self._choose_initialization(X, type='inverse')
|
||||
self._choose_initialization(X, type="inverse")
|
||||
|
||||
else: # we always initialize on testing
|
||||
self._initialize_convolution(X, 'inverse')
|
||||
self._initialize_convolution(X, "inverse")
|
||||
|
||||
# initialize grid
|
||||
X = self._grid_transpose.clone().detach()
|
||||
@@ -398,7 +415,8 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
|
||||
# extract mapped points
|
||||
stacked_input, indeces_channels = self._extract_mapped_points(
|
||||
batch_idx, self._index, x)
|
||||
batch_idx, self._index, x
|
||||
)
|
||||
|
||||
# compute the transpose convolution
|
||||
|
||||
@@ -414,8 +432,9 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
# extract input for each field
|
||||
single_channel_input = stacked_input[idx]
|
||||
rep_idx = torch.tensor(indeces_channels[idx])
|
||||
integral = integrals[batch_idx,
|
||||
idx_in, :].repeat_interleave(rep_idx)
|
||||
integral = integrals[batch_idx, idx_in, :].repeat_interleave(
|
||||
rep_idx
|
||||
)
|
||||
# extract filter
|
||||
net = self._net[idx_conv]
|
||||
# perform transpose convolution for all strides in one field
|
||||
@@ -426,9 +445,11 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
# stacking integral results and sum
|
||||
# filters (for each input fields) in groups
|
||||
# for different output fields
|
||||
res_tmp = torch.stack(res_tmp).reshape(self._input_numb_field,
|
||||
self._output_numb_field,
|
||||
-1).sum(0)
|
||||
res_tmp = (
|
||||
torch.stack(res_tmp)
|
||||
.reshape(self._input_numb_field, self._output_numb_field, -1)
|
||||
.sum(0)
|
||||
)
|
||||
conv_transposed[batch_idx, ..., -1] = res_tmp
|
||||
|
||||
return conv_transposed
|
||||
@@ -460,10 +481,10 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
|
||||
# initialize convolution
|
||||
if self.training: # we choose what to do based on optimization
|
||||
self._choose_initialization(X, type='inverse')
|
||||
self._choose_initialization(X, type="inverse")
|
||||
|
||||
else: # we always initialize on testing
|
||||
self._initialize_convolution(X, 'inverse')
|
||||
self._initialize_convolution(X, "inverse")
|
||||
|
||||
# initialize grid
|
||||
X = self._grid_transpose.clone().detach()
|
||||
@@ -479,11 +500,14 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
|
||||
# accumulator for the convolution on different batches
|
||||
accumulator_batch = torch.zeros(
|
||||
size=(self._grid_transpose.shape[1],
|
||||
self._grid_transpose.shape[2]),
|
||||
size=(
|
||||
self._grid_transpose.shape[1],
|
||||
self._grid_transpose.shape[2],
|
||||
),
|
||||
requires_grad=True,
|
||||
device=X.device,
|
||||
dtype=X.dtype).clone()
|
||||
dtype=X.dtype,
|
||||
).clone()
|
||||
|
||||
for stride_idx, current_stride in enumerate(self._stride):
|
||||
# indeces of points falling into filter range
|
||||
@@ -522,9 +546,10 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
staked_output = net(nn_input_pts[idx_channel_out])
|
||||
|
||||
# perform integral for all strides in one field
|
||||
integral = staked_output * integrals[batch_idx,
|
||||
idx_channel_in,
|
||||
stride_idx]
|
||||
integral = (
|
||||
staked_output
|
||||
* integrals[batch_idx, idx_channel_in, stride_idx]
|
||||
)
|
||||
# append results
|
||||
res_tmp.append(integral.flatten())
|
||||
|
||||
@@ -532,7 +557,7 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
channel_sum = []
|
||||
start = 0
|
||||
for _ in range(self._output_numb_field):
|
||||
tmp = res_tmp[start:start + self._input_numb_field]
|
||||
tmp = res_tmp[start : start + self._input_numb_field]
|
||||
tmp = torch.vstack(tmp).sum(dim=0)
|
||||
channel_sum.append(tmp)
|
||||
start += self._input_numb_field
|
||||
|
||||
@@ -2,14 +2,18 @@ import torch
|
||||
import torch.nn as nn
|
||||
from ...utils import check_consistency
|
||||
|
||||
from pina.model.layers import SpectralConvBlock1D, SpectralConvBlock2D, SpectralConvBlock3D
|
||||
from pina.model.layers import (
|
||||
SpectralConvBlock1D,
|
||||
SpectralConvBlock2D,
|
||||
SpectralConvBlock3D,
|
||||
)
|
||||
|
||||
|
||||
class FourierBlock1D(nn.Module):
|
||||
"""
|
||||
Fourier block implementation for three dimensional
|
||||
input tensor. The combination of Fourier blocks
|
||||
make up the Fourier Neural Operator
|
||||
make up the Fourier Neural Operator
|
||||
|
||||
.. seealso::
|
||||
|
||||
@@ -21,11 +25,13 @@ class FourierBlock1D(nn.Module):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
input_numb_fields,
|
||||
output_numb_fields,
|
||||
n_modes,
|
||||
activation=torch.nn.Tanh):
|
||||
def __init__(
|
||||
self,
|
||||
input_numb_fields,
|
||||
output_numb_fields,
|
||||
n_modes,
|
||||
activation=torch.nn.Tanh,
|
||||
):
|
||||
super().__init__()
|
||||
"""
|
||||
PINA implementation of Fourier block one dimension. The module computes
|
||||
@@ -51,17 +57,18 @@ class FourierBlock1D(nn.Module):
|
||||
self._spectral_conv = SpectralConvBlock1D(
|
||||
input_numb_fields=input_numb_fields,
|
||||
output_numb_fields=output_numb_fields,
|
||||
n_modes=n_modes)
|
||||
n_modes=n_modes,
|
||||
)
|
||||
self._activation = activation()
|
||||
self._linear = nn.Conv1d(input_numb_fields, output_numb_fields, 1)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward computation for Fourier Block. It performs a spectral
|
||||
Forward computation for Fourier Block. It performs a spectral
|
||||
convolution and a linear transformation of the input and sum the
|
||||
results.
|
||||
|
||||
:param x: The input tensor for fourier block, expect of size
|
||||
:param x: The input tensor for fourier block, expect of size
|
||||
``[batch, input_numb_fields, x]``.
|
||||
:type x: torch.Tensor
|
||||
:return: The output tensor obtained from the
|
||||
@@ -75,7 +82,7 @@ class FourierBlock2D(nn.Module):
|
||||
"""
|
||||
Fourier block implementation for two dimensional
|
||||
input tensor. The combination of Fourier blocks
|
||||
make up the Fourier Neural Operator
|
||||
make up the Fourier Neural Operator
|
||||
|
||||
.. seealso::
|
||||
|
||||
@@ -87,18 +94,20 @@ class FourierBlock2D(nn.Module):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
input_numb_fields,
|
||||
output_numb_fields,
|
||||
n_modes,
|
||||
activation=torch.nn.Tanh):
|
||||
def __init__(
|
||||
self,
|
||||
input_numb_fields,
|
||||
output_numb_fields,
|
||||
n_modes,
|
||||
activation=torch.nn.Tanh,
|
||||
):
|
||||
"""
|
||||
PINA implementation of Fourier block two dimensions. The module computes
|
||||
the spectral convolution of the input with a linear kernel in the
|
||||
fourier space, and then it maps the input back to the physical
|
||||
space. The output is then added to a Linear tranformation of the
|
||||
input in the physical space. Finally an activation function is
|
||||
applied to the output.
|
||||
applied to the output.
|
||||
|
||||
The block expects an input of size ``[batch, input_numb_fields, Nx, Ny]``
|
||||
and returns an output of size ``[batch, output_numb_fields, Nx, Ny]``.
|
||||
@@ -118,17 +127,18 @@ class FourierBlock2D(nn.Module):
|
||||
self._spectral_conv = SpectralConvBlock2D(
|
||||
input_numb_fields=input_numb_fields,
|
||||
output_numb_fields=output_numb_fields,
|
||||
n_modes=n_modes)
|
||||
n_modes=n_modes,
|
||||
)
|
||||
self._activation = activation()
|
||||
self._linear = nn.Conv2d(input_numb_fields, output_numb_fields, 1)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward computation for Fourier Block. It performs a spectral
|
||||
Forward computation for Fourier Block. It performs a spectral
|
||||
convolution and a linear transformation of the input and sum the
|
||||
results.
|
||||
|
||||
:param x: The input tensor for fourier block, expect of size
|
||||
:param x: The input tensor for fourier block, expect of size
|
||||
``[batch, input_numb_fields, x, y]``.
|
||||
:type x: torch.Tensor
|
||||
:return: The output tensor obtained from the
|
||||
@@ -142,7 +152,7 @@ class FourierBlock3D(nn.Module):
|
||||
"""
|
||||
Fourier block implementation for three dimensional
|
||||
input tensor. The combination of Fourier blocks
|
||||
make up the Fourier Neural Operator
|
||||
make up the Fourier Neural Operator
|
||||
|
||||
.. seealso::
|
||||
|
||||
@@ -154,18 +164,20 @@ class FourierBlock3D(nn.Module):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
input_numb_fields,
|
||||
output_numb_fields,
|
||||
n_modes,
|
||||
activation=torch.nn.Tanh):
|
||||
def __init__(
|
||||
self,
|
||||
input_numb_fields,
|
||||
output_numb_fields,
|
||||
n_modes,
|
||||
activation=torch.nn.Tanh,
|
||||
):
|
||||
"""
|
||||
PINA implementation of Fourier block three dimensions. The module computes
|
||||
the spectral convolution of the input with a linear kernel in the
|
||||
fourier space, and then it maps the input back to the physical
|
||||
space. The output is then added to a Linear tranformation of the
|
||||
input in the physical space. Finally an activation function is
|
||||
applied to the output.
|
||||
applied to the output.
|
||||
|
||||
The block expects an input of size ``[batch, input_numb_fields, Nx, Ny, Nz]``
|
||||
and returns an output of size ``[batch, output_numb_fields, Nx, Ny, Nz]``.
|
||||
@@ -186,17 +198,18 @@ class FourierBlock3D(nn.Module):
|
||||
self._spectral_conv = SpectralConvBlock3D(
|
||||
input_numb_fields=input_numb_fields,
|
||||
output_numb_fields=output_numb_fields,
|
||||
n_modes=n_modes)
|
||||
n_modes=n_modes,
|
||||
)
|
||||
self._activation = activation()
|
||||
self._linear = nn.Conv3d(input_numb_fields, output_numb_fields, 1)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward computation for Fourier Block. It performs a spectral
|
||||
Forward computation for Fourier Block. It performs a spectral
|
||||
convolution and a linear transformation of the input and sum the
|
||||
results.
|
||||
|
||||
:param x: The input tensor for fourier block, expect of size
|
||||
:param x: The input tensor for fourier block, expect of size
|
||||
``[batch, input_numb_fields, x, y, z]``.
|
||||
:type x: torch.Tensor
|
||||
:return: The output tensor obtained from the
|
||||
|
||||
@@ -10,9 +10,9 @@ class Integral(object):
|
||||
:type param: string
|
||||
"""
|
||||
|
||||
if param == 'discrete':
|
||||
if param == "discrete":
|
||||
self.make_integral = self.integral_param_disc
|
||||
elif param == 'continuous':
|
||||
elif param == "continuous":
|
||||
self.make_integral = self.integral_param_cont
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Module for Base Continuous Convolution class."""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
import torch
|
||||
from .stride import Stride
|
||||
@@ -38,17 +39,17 @@ class PODLayer(torch.nn.Module):
|
||||
:rtype: int
|
||||
"""
|
||||
return self._rank
|
||||
|
||||
|
||||
@rank.setter
|
||||
def rank(self, value):
|
||||
if value < 1 or not isinstance(value, int):
|
||||
raise ValueError('The rank must be positive integer')
|
||||
raise ValueError("The rank must be positive integer")
|
||||
|
||||
self._rank = value
|
||||
|
||||
@property
|
||||
def basis(self):
|
||||
"""
|
||||
"""
|
||||
The POD basis. It is a matrix whose columns are the first `self.rank` POD modes.
|
||||
|
||||
:rtype: torch.Tensor
|
||||
@@ -56,7 +57,7 @@ class PODLayer(torch.nn.Module):
|
||||
if self._basis is None:
|
||||
return None
|
||||
|
||||
return self._basis[:self.rank]
|
||||
return self._basis[: self.rank]
|
||||
|
||||
@property
|
||||
def scaler(self):
|
||||
@@ -67,10 +68,12 @@ class PODLayer(torch.nn.Module):
|
||||
:rtype: dict
|
||||
"""
|
||||
if self._scaler is None:
|
||||
return
|
||||
return
|
||||
|
||||
return {'mean': self._scaler['mean'][:self.rank],
|
||||
'std': self._scaler['std'][:self.rank]}
|
||||
return {
|
||||
"mean": self._scaler["mean"][: self.rank],
|
||||
"std": self._scaler["std"][: self.rank],
|
||||
}
|
||||
|
||||
@property
|
||||
def scale_coefficients(self):
|
||||
@@ -105,8 +108,9 @@ class PODLayer(torch.nn.Module):
|
||||
:param torch.Tensor coeffs: The coefficients to be scaled.
|
||||
"""
|
||||
self._scaler = {
|
||||
'std': torch.std(coeffs, dim=1),
|
||||
'mean': torch.mean(coeffs, dim=1)}
|
||||
"std": torch.std(coeffs, dim=1),
|
||||
"mean": torch.mean(coeffs, dim=1),
|
||||
}
|
||||
|
||||
def _fit_pod(self, X):
|
||||
"""
|
||||
@@ -114,7 +118,7 @@ class PODLayer(torch.nn.Module):
|
||||
|
||||
:param torch.Tensor X: The tensor to be reduced.
|
||||
"""
|
||||
if X.device.type == 'mps': # svd_lowrank not arailable for mps
|
||||
if X.device.type == "mps": # svd_lowrank not arailable for mps
|
||||
self._basis = torch.svd(X.T)[0].T
|
||||
else:
|
||||
self._basis = torch.svd_lowrank(X.T, q=X.shape[0])[0].T
|
||||
@@ -142,7 +146,8 @@ class PODLayer(torch.nn.Module):
|
||||
"""
|
||||
if self._basis is None:
|
||||
raise RuntimeError(
|
||||
'The POD layer needs to be fitted before being used.')
|
||||
"The POD layer needs to be fitted before being used."
|
||||
)
|
||||
|
||||
coeff = torch.matmul(self.basis, X.T)
|
||||
if coeff.ndim == 1:
|
||||
@@ -150,28 +155,29 @@ class PODLayer(torch.nn.Module):
|
||||
|
||||
coeff = coeff.T
|
||||
if self.__scale_coefficients:
|
||||
coeff = (coeff - self.scaler['mean']) / self.scaler['std']
|
||||
coeff = (coeff - self.scaler["mean"]) / self.scaler["std"]
|
||||
|
||||
return coeff
|
||||
|
||||
def expand(self, coeff):
|
||||
"""
|
||||
"""
|
||||
Expand the given coefficients to the original space. The POD layer needs
|
||||
to be fitted before being used.
|
||||
|
||||
|
||||
:param torch.Tensor coeff: The coefficients to be expanded.
|
||||
:return: The expanded tensor.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
if self._basis is None:
|
||||
raise RuntimeError(
|
||||
'The POD layer needs to be trained before being used.')
|
||||
"The POD layer needs to be trained before being used."
|
||||
)
|
||||
|
||||
if self.__scale_coefficients:
|
||||
coeff = coeff * self.scaler['std'] + self.scaler['mean']
|
||||
coeff = coeff * self.scaler["std"] + self.scaler["mean"]
|
||||
predicted = torch.matmul(self.basis.T, coeff.T).T
|
||||
|
||||
if predicted.ndim == 1:
|
||||
predicted = predicted.unsqueeze(0)
|
||||
|
||||
return predicted
|
||||
return predicted
|
||||
|
||||
@@ -16,18 +16,20 @@ class ResidualBlock(nn.Module):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
input_dim,
|
||||
output_dim,
|
||||
hidden_dim,
|
||||
spectral_norm=False,
|
||||
activation=torch.nn.ReLU()):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim,
|
||||
output_dim,
|
||||
hidden_dim,
|
||||
spectral_norm=False,
|
||||
activation=torch.nn.ReLU(),
|
||||
):
|
||||
"""
|
||||
Initializes the ResidualBlock module.
|
||||
|
||||
:param int input_dim: Dimension of the input to pass to the
|
||||
feedforward linear layer.
|
||||
:param int output_dim: Dimension of the output from the
|
||||
:param int output_dim: Dimension of the output from the
|
||||
residual layer.
|
||||
:param int hidden_dim: Hidden dimension for mapping the input
|
||||
(first block).
|
||||
@@ -82,6 +84,7 @@ class ResidualBlock(nn.Module):
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class EnhancedLinear(torch.nn.Module):
|
||||
"""
|
||||
A wrapper class for enhancing a linear layer with activation and/or dropout.
|
||||
@@ -132,8 +135,9 @@ class EnhancedLinear(torch.nn.Module):
|
||||
self._model = torch.nn.Sequential(layer, self._drop(dropout))
|
||||
|
||||
elif (dropout is not None) and (activation is not None):
|
||||
self._model = torch.nn.Sequential(layer, activation,
|
||||
self._drop(dropout))
|
||||
self._model = torch.nn.Sequential(
|
||||
layer, activation, self._drop(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
|
||||
@@ -37,18 +37,23 @@ class SpectralConvBlock1D(nn.Module):
|
||||
self._output_channels = output_numb_fields
|
||||
|
||||
# scaling factor
|
||||
scale = (1. / (self._input_channels * self._output_channels))
|
||||
self._weights = nn.Parameter(scale * torch.rand(self._input_channels,
|
||||
self._output_channels,
|
||||
self._modes,
|
||||
dtype=torch.cfloat))
|
||||
scale = 1.0 / (self._input_channels * self._output_channels)
|
||||
self._weights = nn.Parameter(
|
||||
scale
|
||||
* torch.rand(
|
||||
self._input_channels,
|
||||
self._output_channels,
|
||||
self._modes,
|
||||
dtype=torch.cfloat,
|
||||
)
|
||||
)
|
||||
|
||||
def _compute_mult1d(self, input, weights):
|
||||
"""
|
||||
Compute the matrix multiplication of the input
|
||||
with the linear kernel weights.
|
||||
|
||||
:param input: The input tensor, expect of size
|
||||
:param input: The input tensor, expect of size
|
||||
``[batch, input_numb_fields, x]``.
|
||||
:type input: torch.Tensor
|
||||
:param weights: The kernel weights, expect of
|
||||
@@ -64,7 +69,7 @@ class SpectralConvBlock1D(nn.Module):
|
||||
"""
|
||||
Forward computation for Spectral Convolution.
|
||||
|
||||
:param x: The input tensor, expect of size
|
||||
:param x: The input tensor, expect of size
|
||||
``[batch, input_numb_fields, x]``.
|
||||
:type x: torch.Tensor
|
||||
:return: The output tensor obtained from the
|
||||
@@ -77,13 +82,16 @@ class SpectralConvBlock1D(nn.Module):
|
||||
x_ft = torch.fft.rfft(x)
|
||||
|
||||
# Multiply relevant Fourier modes
|
||||
out_ft = torch.zeros(batch_size,
|
||||
self._output_channels,
|
||||
x.size(-1) // 2 + 1,
|
||||
device=x.device,
|
||||
dtype=torch.cfloat)
|
||||
out_ft[:, :, :self._modes] = self._compute_mult1d(
|
||||
x_ft[:, :, :self._modes], self._weights)
|
||||
out_ft = torch.zeros(
|
||||
batch_size,
|
||||
self._output_channels,
|
||||
x.size(-1) // 2 + 1,
|
||||
device=x.device,
|
||||
dtype=torch.cfloat,
|
||||
)
|
||||
out_ft[:, :, : self._modes] = self._compute_mult1d(
|
||||
x_ft[:, :, : self._modes], self._weights
|
||||
)
|
||||
|
||||
# Return to physical space
|
||||
return torch.fft.irfft(out_ft, n=x.size(-1))
|
||||
@@ -119,17 +127,19 @@ class SpectralConvBlock2D(nn.Module):
|
||||
if isinstance(n_modes, (tuple, list)):
|
||||
if len(n_modes) != 2:
|
||||
raise ValueError(
|
||||
'Expected n_modes to be a list or tuple of len two, '
|
||||
'with each entry corresponding to the number of modes '
|
||||
'for each dimension ')
|
||||
"Expected n_modes to be a list or tuple of len two, "
|
||||
"with each entry corresponding to the number of modes "
|
||||
"for each dimension "
|
||||
)
|
||||
elif isinstance(n_modes, int):
|
||||
n_modes = [n_modes] * 2
|
||||
else:
|
||||
raise ValueError(
|
||||
'Expected n_modes to be a list or tuple of len two, '
|
||||
'with each entry corresponding to the number of modes '
|
||||
'for each dimension; or an int value representing the '
|
||||
'number of modes for all dimensions')
|
||||
"Expected n_modes to be a list or tuple of len two, "
|
||||
"with each entry corresponding to the number of modes "
|
||||
"for each dimension; or an int value representing the "
|
||||
"number of modes for all dimensions"
|
||||
)
|
||||
|
||||
# assign variables
|
||||
self._modes = n_modes
|
||||
@@ -137,24 +147,34 @@ class SpectralConvBlock2D(nn.Module):
|
||||
self._output_channels = output_numb_fields
|
||||
|
||||
# scaling factor
|
||||
scale = (1. / (self._input_channels * self._output_channels))
|
||||
self._weights1 = nn.Parameter(scale * torch.rand(self._input_channels,
|
||||
self._output_channels,
|
||||
self._modes[0],
|
||||
self._modes[1],
|
||||
dtype=torch.cfloat))
|
||||
self._weights2 = nn.Parameter(scale * torch.rand(self._input_channels,
|
||||
self._output_channels,
|
||||
self._modes[0],
|
||||
self._modes[1],
|
||||
dtype=torch.cfloat))
|
||||
scale = 1.0 / (self._input_channels * self._output_channels)
|
||||
self._weights1 = nn.Parameter(
|
||||
scale
|
||||
* torch.rand(
|
||||
self._input_channels,
|
||||
self._output_channels,
|
||||
self._modes[0],
|
||||
self._modes[1],
|
||||
dtype=torch.cfloat,
|
||||
)
|
||||
)
|
||||
self._weights2 = nn.Parameter(
|
||||
scale
|
||||
* torch.rand(
|
||||
self._input_channels,
|
||||
self._output_channels,
|
||||
self._modes[0],
|
||||
self._modes[1],
|
||||
dtype=torch.cfloat,
|
||||
)
|
||||
)
|
||||
|
||||
def _compute_mult2d(self, input, weights):
|
||||
"""
|
||||
Compute the matrix multiplication of the input
|
||||
with the linear kernel weights.
|
||||
|
||||
:param input: The input tensor, expect of size
|
||||
:param input: The input tensor, expect of size
|
||||
``[batch, input_numb_fields, x, y]``.
|
||||
:type input: torch.Tensor
|
||||
:param weights: The kernel weights, expect of
|
||||
@@ -170,7 +190,7 @@ class SpectralConvBlock2D(nn.Module):
|
||||
"""
|
||||
Forward computation for Spectral Convolution.
|
||||
|
||||
:param x: The input tensor, expect of size
|
||||
:param x: The input tensor, expect of size
|
||||
``[batch, input_numb_fields, x, y]``.
|
||||
:type x: torch.Tensor
|
||||
:return: The output tensor obtained from the
|
||||
@@ -184,16 +204,22 @@ class SpectralConvBlock2D(nn.Module):
|
||||
x_ft = torch.fft.rfft2(x)
|
||||
|
||||
# Multiply relevant Fourier modes
|
||||
out_ft = torch.zeros(batch_size,
|
||||
self._output_channels,
|
||||
x.size(-2),
|
||||
x.size(-1) // 2 + 1,
|
||||
device=x.device,
|
||||
dtype=torch.cfloat)
|
||||
out_ft[:, :, :self._modes[0], :self._modes[1]] = self._compute_mult2d(
|
||||
x_ft[:, :, :self._modes[0], :self._modes[1]], self._weights1)
|
||||
out_ft[:, :, -self._modes[0]:, :self._modes[1]:] = self._compute_mult2d(
|
||||
x_ft[:, :, -self._modes[0]:, :self._modes[1]], self._weights2)
|
||||
out_ft = torch.zeros(
|
||||
batch_size,
|
||||
self._output_channels,
|
||||
x.size(-2),
|
||||
x.size(-1) // 2 + 1,
|
||||
device=x.device,
|
||||
dtype=torch.cfloat,
|
||||
)
|
||||
out_ft[:, :, : self._modes[0], : self._modes[1]] = self._compute_mult2d(
|
||||
x_ft[:, :, : self._modes[0], : self._modes[1]], self._weights1
|
||||
)
|
||||
out_ft[:, :, -self._modes[0] :, : self._modes[1] :] = (
|
||||
self._compute_mult2d(
|
||||
x_ft[:, :, -self._modes[0] :, : self._modes[1]], self._weights2
|
||||
)
|
||||
)
|
||||
|
||||
# Return to physical space
|
||||
return torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1)))
|
||||
@@ -230,17 +256,19 @@ class SpectralConvBlock3D(nn.Module):
|
||||
if isinstance(n_modes, (tuple, list)):
|
||||
if len(n_modes) != 3:
|
||||
raise ValueError(
|
||||
'Expected n_modes to be a list or tuple of len three, '
|
||||
'with each entry corresponding to the number of modes '
|
||||
'for each dimension ')
|
||||
"Expected n_modes to be a list or tuple of len three, "
|
||||
"with each entry corresponding to the number of modes "
|
||||
"for each dimension "
|
||||
)
|
||||
elif isinstance(n_modes, int):
|
||||
n_modes = [n_modes] * 3
|
||||
else:
|
||||
raise ValueError(
|
||||
'Expected n_modes to be a list or tuple of len three, '
|
||||
'with each entry corresponding to the number of modes '
|
||||
'for each dimension; or an int value representing the '
|
||||
'number of modes for all dimensions')
|
||||
"Expected n_modes to be a list or tuple of len three, "
|
||||
"with each entry corresponding to the number of modes "
|
||||
"for each dimension; or an int value representing the "
|
||||
"number of modes for all dimensions"
|
||||
)
|
||||
|
||||
# assign variables
|
||||
self._modes = n_modes
|
||||
@@ -248,38 +276,58 @@ class SpectralConvBlock3D(nn.Module):
|
||||
self._output_channels = output_numb_fields
|
||||
|
||||
# scaling factor
|
||||
scale = (1. / (self._input_channels * self._output_channels))
|
||||
self._weights1 = nn.Parameter(scale * torch.rand(self._input_channels,
|
||||
self._output_channels,
|
||||
self._modes[0],
|
||||
self._modes[1],
|
||||
self._modes[2],
|
||||
dtype=torch.cfloat))
|
||||
self._weights2 = nn.Parameter(scale * torch.rand(self._input_channels,
|
||||
self._output_channels,
|
||||
self._modes[0],
|
||||
self._modes[1],
|
||||
self._modes[2],
|
||||
dtype=torch.cfloat))
|
||||
self._weights3 = nn.Parameter(scale * torch.rand(self._input_channels,
|
||||
self._output_channels,
|
||||
self._modes[0],
|
||||
self._modes[1],
|
||||
self._modes[2],
|
||||
dtype=torch.cfloat))
|
||||
self._weights4 = nn.Parameter(scale * torch.rand(self._input_channels,
|
||||
self._output_channels,
|
||||
self._modes[0],
|
||||
self._modes[1],
|
||||
self._modes[2],
|
||||
dtype=torch.cfloat))
|
||||
scale = 1.0 / (self._input_channels * self._output_channels)
|
||||
self._weights1 = nn.Parameter(
|
||||
scale
|
||||
* torch.rand(
|
||||
self._input_channels,
|
||||
self._output_channels,
|
||||
self._modes[0],
|
||||
self._modes[1],
|
||||
self._modes[2],
|
||||
dtype=torch.cfloat,
|
||||
)
|
||||
)
|
||||
self._weights2 = nn.Parameter(
|
||||
scale
|
||||
* torch.rand(
|
||||
self._input_channels,
|
||||
self._output_channels,
|
||||
self._modes[0],
|
||||
self._modes[1],
|
||||
self._modes[2],
|
||||
dtype=torch.cfloat,
|
||||
)
|
||||
)
|
||||
self._weights3 = nn.Parameter(
|
||||
scale
|
||||
* torch.rand(
|
||||
self._input_channels,
|
||||
self._output_channels,
|
||||
self._modes[0],
|
||||
self._modes[1],
|
||||
self._modes[2],
|
||||
dtype=torch.cfloat,
|
||||
)
|
||||
)
|
||||
self._weights4 = nn.Parameter(
|
||||
scale
|
||||
* torch.rand(
|
||||
self._input_channels,
|
||||
self._output_channels,
|
||||
self._modes[0],
|
||||
self._modes[1],
|
||||
self._modes[2],
|
||||
dtype=torch.cfloat,
|
||||
)
|
||||
)
|
||||
|
||||
def _compute_mult3d(self, input, weights):
|
||||
"""
|
||||
Compute the matrix multiplication of the input
|
||||
with the linear kernel weights.
|
||||
|
||||
:param input: The input tensor, expect of size
|
||||
:param input: The input tensor, expect of size
|
||||
``[batch, input_numb_fields, x, y, z]``.
|
||||
:type input: torch.Tensor
|
||||
:param weights: The kernel weights, expect of
|
||||
@@ -295,7 +343,7 @@ class SpectralConvBlock3D(nn.Module):
|
||||
"""
|
||||
Forward computation for Spectral Convolution.
|
||||
|
||||
:param x: The input tensor, expect of size
|
||||
:param x: The input tensor, expect of size
|
||||
``[batch, input_numb_fields, x, y, z]``.
|
||||
:type x: torch.Tensor
|
||||
:return: The output tensor obtained from the
|
||||
@@ -309,13 +357,15 @@ class SpectralConvBlock3D(nn.Module):
|
||||
x_ft = torch.fft.rfftn(x, dim=[-3, -2, -1])
|
||||
|
||||
# Multiply relevant Fourier modes
|
||||
out_ft = torch.zeros(batch_size,
|
||||
self._output_channels,
|
||||
x.size(-3),
|
||||
x.size(-2),
|
||||
x.size(-1) // 2 + 1,
|
||||
device=x.device,
|
||||
dtype=torch.cfloat)
|
||||
out_ft = torch.zeros(
|
||||
batch_size,
|
||||
self._output_channels,
|
||||
x.size(-3),
|
||||
x.size(-2),
|
||||
x.size(-1) // 2 + 1,
|
||||
device=x.device,
|
||||
dtype=torch.cfloat,
|
||||
)
|
||||
|
||||
slice0 = (
|
||||
slice(None),
|
||||
|
||||
@@ -60,7 +60,8 @@ class Stride(object):
|
||||
|
||||
if seq_direction != seq_jumps:
|
||||
raise IndexError(
|
||||
"direction and jumps must have zero in the same index")
|
||||
"direction and jumps must have zero in the same index"
|
||||
)
|
||||
|
||||
if seq_jumps:
|
||||
for i in seq_jumps:
|
||||
|
||||
@@ -3,8 +3,9 @@ import torch
|
||||
|
||||
def check_point(x, current_stride, dim):
|
||||
max_stride = current_stride + dim
|
||||
indeces = torch.logical_and(x[..., :-1] < max_stride, x[..., :-1]
|
||||
>= current_stride).all(dim=-1)
|
||||
indeces = torch.logical_and(
|
||||
x[..., :-1] < max_stride, x[..., :-1] >= current_stride
|
||||
).all(dim=-1)
|
||||
return indeces
|
||||
|
||||
|
||||
@@ -32,12 +33,12 @@ def optimizing(f):
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
|
||||
if kwargs['type'] == 'forward':
|
||||
if kwargs["type"] == "forward":
|
||||
if not wrapper.has_run_inverse:
|
||||
wrapper.has_run_inverse = True
|
||||
return f(*args, **kwargs)
|
||||
|
||||
if kwargs['type'] == 'inverse':
|
||||
if kwargs["type"] == "inverse":
|
||||
if not wrapper.has_run:
|
||||
wrapper.has_run = True
|
||||
return f(*args, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user