🎨 Format Python code with psf/black
This commit is contained in:
@@ -12,7 +12,7 @@ class FNO(torch.nn.Module):
|
||||
|
||||
Fourier Neural Operator (FNO) is a general architecture for learning Operators.
|
||||
Unlike traditional machine learning methods FNO is designed to map
|
||||
entire functions to other functions. It can be trained both with
|
||||
entire functions to other functions. It can be trained both with
|
||||
Supervised learning strategies. FNO does global convolution by performing the
|
||||
operation on the Fourier space.
|
||||
|
||||
@@ -25,17 +25,19 @@ class FNO(torch.nn.Module):
|
||||
<https://arxiv.org/abs/2010.08895>`_
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
lifting_net,
|
||||
projecting_net,
|
||||
n_modes,
|
||||
dimensions=3,
|
||||
padding=8,
|
||||
padding_type="constant",
|
||||
inner_size=20,
|
||||
n_layers=2,
|
||||
func=nn.Tanh,
|
||||
layers=None):
|
||||
def __init__(
|
||||
self,
|
||||
lifting_net,
|
||||
projecting_net,
|
||||
n_modes,
|
||||
dimensions=3,
|
||||
padding=8,
|
||||
padding_type="constant",
|
||||
inner_size=20,
|
||||
n_layers=2,
|
||||
func=nn.Tanh,
|
||||
layers=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# check type consistency
|
||||
@@ -51,11 +53,12 @@ class FNO(torch.nn.Module):
|
||||
if isinstance(layers, (tuple, list)):
|
||||
check_consistency(layers, int)
|
||||
else:
|
||||
raise ValueError('layers must be tuple or list of int.')
|
||||
raise ValueError("layers must be tuple or list of int.")
|
||||
if not isinstance(n_modes, (list, tuple, int)):
|
||||
raise ValueError(
|
||||
'n_modes must be a int or list or tuple of valid modes.'
|
||||
' More information on the official documentation.')
|
||||
"n_modes must be a int or list or tuple of valid modes."
|
||||
" More information on the official documentation."
|
||||
)
|
||||
|
||||
# assign variables
|
||||
# TODO check input lifting net and input projecting net
|
||||
@@ -71,7 +74,7 @@ class FNO(torch.nn.Module):
|
||||
elif dimensions == 3:
|
||||
fourier_layer = FourierBlock3D
|
||||
else:
|
||||
raise NotImplementedError('FNO implemented only for 1D/2D/3D data.')
|
||||
raise NotImplementedError("FNO implemented only for 1D/2D/3D data.")
|
||||
|
||||
# Here we build the FNO by stacking Fourier Blocks
|
||||
|
||||
@@ -83,17 +86,20 @@ class FNO(torch.nn.Module):
|
||||
if isinstance(func, list):
|
||||
if len(layers) != len(func):
|
||||
raise RuntimeError(
|
||||
'Uncosistent number of layers and functions.')
|
||||
"Uncosistent number of layers and functions."
|
||||
)
|
||||
self._functions = func
|
||||
else:
|
||||
self._functions = [func for _ in range(len(layers))]
|
||||
|
||||
# 3. Assign modes functions for each FNO layer
|
||||
if isinstance(n_modes, list):
|
||||
if all(isinstance(i, list)
|
||||
for i in n_modes) and len(layers) != len(n_modes):
|
||||
if all(isinstance(i, list) for i in n_modes) and len(layers) != len(
|
||||
n_modes
|
||||
):
|
||||
raise RuntimeError(
|
||||
'Uncosistent number of layers and functions.')
|
||||
"Uncosistent number of layers and functions."
|
||||
)
|
||||
elif all(isinstance(i, int) for i in n_modes):
|
||||
n_modes = [n_modes] * len(layers)
|
||||
else:
|
||||
@@ -109,10 +115,13 @@ class FNO(torch.nn.Module):
|
||||
self._layers = []
|
||||
for i in range(len(tmp_layers) - 1):
|
||||
self._layers.append(
|
||||
fourier_layer(input_numb_fields=tmp_layers[i],
|
||||
output_numb_fields=tmp_layers[i + 1],
|
||||
n_modes=n_modes[i],
|
||||
activation=self._functions[i]))
|
||||
fourier_layer(
|
||||
input_numb_fields=tmp_layers[i],
|
||||
output_numb_fields=tmp_layers[i + 1],
|
||||
n_modes=n_modes[i],
|
||||
activation=self._functions[i],
|
||||
)
|
||||
)
|
||||
self._layers = nn.Sequential(*self._layers)
|
||||
|
||||
# 5. Padding values for spectral conv
|
||||
@@ -139,8 +148,10 @@ class FNO(torch.nn.Module):
|
||||
:return: The output tensor obtained from the FNO.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
if isinstance(x, LabelTensor): #TODO remove when Network is fixed
|
||||
warnings.warn('LabelTensor passed as input is not allowed, casting LabelTensor to Torch.Tensor')
|
||||
if isinstance(x, LabelTensor): # TODO remove when Network is fixed
|
||||
warnings.warn(
|
||||
"LabelTensor passed as input is not allowed, casting LabelTensor to Torch.Tensor"
|
||||
)
|
||||
x = x.as_subclass(torch.Tensor)
|
||||
|
||||
# lifting the input in higher dimensional space
|
||||
|
||||
Reference in New Issue
Block a user