Fix Codacy Warnings (#477)
--------- Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
e3790e049a
commit
4177bfbb50
@@ -2,10 +2,10 @@
|
||||
Fourier Neural Operator Module.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ..label_tensor import LabelTensor
|
||||
import warnings
|
||||
import torch
|
||||
from torch import nn
|
||||
from ..label_tensor import LabelTensor
|
||||
from ..utils import check_consistency
|
||||
from .block.fourier_block import FourierBlock1D, FourierBlock2D, FourierBlock3D
|
||||
from .kernel_neural_operator import KernelNeuralOperator
|
||||
@@ -57,36 +57,22 @@ class FourierIntegralKernel(torch.nn.Module):
|
||||
super().__init__()
|
||||
|
||||
# check type consistency
|
||||
check_consistency(dimensions, int)
|
||||
check_consistency(padding, int)
|
||||
check_consistency(padding_type, str)
|
||||
check_consistency(inner_size, int)
|
||||
check_consistency(n_layers, int)
|
||||
check_consistency(func, nn.Module, subclass=True)
|
||||
|
||||
if layers is not None:
|
||||
if isinstance(layers, (tuple, list)):
|
||||
check_consistency(layers, int)
|
||||
else:
|
||||
raise ValueError("layers must be tuple or list of int.")
|
||||
if not isinstance(n_modes, (list, tuple, int)):
|
||||
raise ValueError(
|
||||
"n_modes must be a int or list or tuple of valid modes."
|
||||
" More information on the official documentation."
|
||||
)
|
||||
self._check_consistency(
|
||||
dimensions,
|
||||
padding,
|
||||
padding_type,
|
||||
inner_size,
|
||||
n_layers,
|
||||
func,
|
||||
layers,
|
||||
n_modes,
|
||||
)
|
||||
|
||||
# assign padding
|
||||
self._padding = padding
|
||||
|
||||
# initialize fourier layer for each dimension
|
||||
if dimensions == 1:
|
||||
fourier_layer = FourierBlock1D
|
||||
elif dimensions == 2:
|
||||
fourier_layer = FourierBlock2D
|
||||
elif dimensions == 3:
|
||||
fourier_layer = FourierBlock3D
|
||||
else:
|
||||
raise NotImplementedError("FNO implemented only for 1D/2D/3D data.")
|
||||
fourier_layer = self._get_fourier_block(dimensions)
|
||||
|
||||
# Here we build the FNO kernels by stacking Fourier Blocks
|
||||
|
||||
@@ -113,24 +99,24 @@ class FourierIntegralKernel(torch.nn.Module):
|
||||
raise RuntimeError(
|
||||
"Uncosistent number of layers and functions."
|
||||
)
|
||||
elif all(isinstance(i, int) for i in n_modes):
|
||||
if all(isinstance(i, int) for i in n_modes):
|
||||
n_modes = [n_modes] * len(layers)
|
||||
else:
|
||||
n_modes = [n_modes] * len(layers)
|
||||
|
||||
# 4. Build the FNO network
|
||||
_layers = []
|
||||
tmp_layers = [input_numb_fields] + layers + [output_numb_fields]
|
||||
for i in range(len(layers)):
|
||||
_layers.append(
|
||||
self._layers = nn.Sequential(
|
||||
*[
|
||||
fourier_layer(
|
||||
input_numb_fields=tmp_layers[i],
|
||||
output_numb_fields=tmp_layers[i + 1],
|
||||
n_modes=n_modes[i],
|
||||
activation=_functions[i],
|
||||
)
|
||||
)
|
||||
self._layers = nn.Sequential(*_layers)
|
||||
for i in range(len(layers))
|
||||
]
|
||||
)
|
||||
|
||||
# 5. Padding values for spectral conv
|
||||
if isinstance(padding, int):
|
||||
@@ -158,14 +144,14 @@ class FourierIntegralKernel(torch.nn.Module):
|
||||
:return: The output tensor obtained from the kernels convolution.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
if isinstance(x, LabelTensor): # TODO remove when Network is fixed
|
||||
if isinstance(x, LabelTensor):
|
||||
warnings.warn(
|
||||
"LabelTensor passed as input is not allowed,"
|
||||
" casting LabelTensor to Torch.Tensor"
|
||||
)
|
||||
x = x.as_subclass(torch.Tensor)
|
||||
# permuting the input [batch, channels, x, y, ...]
|
||||
permutation_idx = [0, x.ndim - 1, *[i for i in range(1, x.ndim - 1)]]
|
||||
permutation_idx = [0, x.ndim - 1, *list(range(1, x.ndim - 1))]
|
||||
x = x.permute(permutation_idx)
|
||||
|
||||
# padding the input
|
||||
@@ -179,11 +165,50 @@ class FourierIntegralKernel(torch.nn.Module):
|
||||
x = x[idxs]
|
||||
|
||||
# permuting back [batch, x, y, ..., channels]
|
||||
permutation_idx = [0, *[i for i in range(2, x.ndim)], 1]
|
||||
permutation_idx = [0, *list(range(2, x.ndim)), 1]
|
||||
x = x.permute(permutation_idx)
|
||||
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def _check_consistency(
|
||||
dimensions,
|
||||
padding,
|
||||
padding_type,
|
||||
inner_size,
|
||||
n_layers,
|
||||
func,
|
||||
layers,
|
||||
n_modes,
|
||||
):
|
||||
check_consistency(dimensions, int)
|
||||
check_consistency(padding, int)
|
||||
check_consistency(padding_type, str)
|
||||
check_consistency(inner_size, int)
|
||||
check_consistency(n_layers, int)
|
||||
check_consistency(func, nn.Module, subclass=True)
|
||||
|
||||
if layers is not None:
|
||||
if isinstance(layers, (tuple, list)):
|
||||
check_consistency(layers, int)
|
||||
else:
|
||||
raise ValueError("layers must be tuple or list of int.")
|
||||
if not isinstance(n_modes, (list, tuple, int)):
|
||||
raise ValueError(
|
||||
"n_modes must be a int or list or tuple of valid modes."
|
||||
" More information on the official documentation."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_fourier_block(dimensions):
|
||||
if dimensions == 1:
|
||||
return FourierBlock1D
|
||||
if dimensions == 2:
|
||||
return FourierBlock2D
|
||||
if dimensions == 3:
|
||||
return FourierBlock3D
|
||||
raise NotImplementedError("FNO implemented only for 1D/2D/3D data.")
|
||||
|
||||
|
||||
class FNO(KernelNeuralOperator):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user