🎨 Format Python code with psf/black

This commit is contained in:
ndem0
2024-02-21 10:15:57 +00:00
committed by Nicola Demo
parent e516e779f9
commit c2529d325a
3 changed files with 78 additions and 63 deletions

View File

@@ -1,12 +1,12 @@
__all__ = [ __all__ = [
'FeedForward', "FeedForward",
'ResidualFeedForward', "ResidualFeedForward",
'MultiFeedForward', "MultiFeedForward",
'DeepONet', "DeepONet",
'MIONet', "MIONet",
'FNO', "FNO",
'FourierIntegralKernel', "FourierIntegralKernel",
'KernelNeuralOperator' "KernelNeuralOperator",
] ]
from .feed_forward import FeedForward, ResidualFeedForward from .feed_forward import FeedForward, ResidualFeedForward

View File

@@ -44,6 +44,7 @@ class KernelNeuralOperator(torch.nn.Module):
spaces with applications to PDEs*. Journal of Machine Learning spaces with applications to PDEs*. Journal of Machine Learning
Research, 24(89), 1-97. Research, 24(89), 1-97.
""" """
def __init__(self, lifting_operator, integral_kernels, projection_operator): def __init__(self, lifting_operator, integral_kernels, projection_operator):
""" """
:param torch.nn.Module lifting_operator: The lifting operator :param torch.nn.Module lifting_operator: The lifting operator

View File

@@ -28,7 +28,9 @@ class FourierIntegralKernel(torch.nn.Module):
DOI: `arXiv preprint arXiv:2010.08895. DOI: `arXiv preprint arXiv:2010.08895.
<https://arxiv.org/abs/2010.08895>`_ <https://arxiv.org/abs/2010.08895>`_
""" """
def __init__(self,
def __init__(
self,
input_numb_fields, input_numb_fields,
output_numb_fields, output_numb_fields,
n_modes, n_modes,
@@ -38,7 +40,8 @@ class FourierIntegralKernel(torch.nn.Module):
inner_size=20, inner_size=20,
n_layers=2, n_layers=2,
func=nn.Tanh, func=nn.Tanh,
layers=None): layers=None,
):
""" """
:param int input_numb_fields: Number of input fields. :param int input_numb_fields: Number of input fields.
:param int output_numb_fields: Number of output fields. :param int output_numb_fields: Number of output fields.
@@ -69,7 +72,8 @@ class FourierIntegralKernel(torch.nn.Module):
if not isinstance(n_modes, (list, tuple, int)): if not isinstance(n_modes, (list, tuple, int)):
raise ValueError( raise ValueError(
"n_modes must be a int or list or tuple of valid modes." "n_modes must be a int or list or tuple of valid modes."
" More information on the official documentation.") " More information on the official documentation."
)
# assign padding # assign padding
self._padding = padding self._padding = padding
@@ -82,9 +86,7 @@ class FourierIntegralKernel(torch.nn.Module):
elif dimensions == 3: elif dimensions == 3:
fourier_layer = FourierBlock3D fourier_layer = FourierBlock3D
else: else:
raise NotImplementedError( raise NotImplementedError("FNO implemented only for 1D/2D/3D data.")
"FNO implemented only for 1D/2D/3D data."
)
# Here we build the FNO kernels by stacking Fourier Blocks # Here we build the FNO kernels by stacking Fourier Blocks
@@ -96,7 +98,8 @@ class FourierIntegralKernel(torch.nn.Module):
if isinstance(func, list): if isinstance(func, list):
if len(layers) != len(func): if len(layers) != len(func):
raise RuntimeError( raise RuntimeError(
'Uncosistent number of layers and functions.') "Uncosistent number of layers and functions."
)
_functions = func _functions = func
else: else:
_functions = [func for _ in range(len(layers) - 1)] _functions = [func for _ in range(len(layers) - 1)]
@@ -104,10 +107,12 @@ class FourierIntegralKernel(torch.nn.Module):
# 3. Assign modes functions for each FNO layer # 3. Assign modes functions for each FNO layer
if isinstance(n_modes, list): if isinstance(n_modes, list):
if all(isinstance(i, list) if all(isinstance(i, list) for i in n_modes) and len(layers) != len(
for i in n_modes) and len(layers) != len(n_modes): n_modes
):
raise RuntimeError( raise RuntimeError(
"Uncosistent number of layers and functions.") "Uncosistent number of layers and functions."
)
elif all(isinstance(i, int) for i in n_modes): elif all(isinstance(i, int) for i in n_modes):
n_modes = [n_modes] * len(layers) n_modes = [n_modes] * len(layers)
else: else:
@@ -118,10 +123,13 @@ class FourierIntegralKernel(torch.nn.Module):
tmp_layers = [input_numb_fields] + layers + [output_numb_fields] tmp_layers = [input_numb_fields] + layers + [output_numb_fields]
for i in range(len(layers)): for i in range(len(layers)):
_layers.append( _layers.append(
fourier_layer(input_numb_fields=tmp_layers[i], fourier_layer(
input_numb_fields=tmp_layers[i],
output_numb_fields=tmp_layers[i + 1], output_numb_fields=tmp_layers[i + 1],
n_modes=n_modes[i], n_modes=n_modes[i],
activation=_functions[i])) activation=_functions[i],
)
)
self._layers = nn.Sequential(*_layers) self._layers = nn.Sequential(*_layers)
# 5. Padding values for spectral conv # 5. Padding values for spectral conv
@@ -152,8 +160,8 @@ class FourierIntegralKernel(torch.nn.Module):
""" """
if isinstance(x, LabelTensor): # TODO remove when Network is fixed if isinstance(x, LabelTensor): # TODO remove when Network is fixed
warnings.warn( warnings.warn(
'LabelTensor passed as input is not allowed,' "LabelTensor passed as input is not allowed,"
' casting LabelTensor to Torch.Tensor' " casting LabelTensor to Torch.Tensor"
) )
x = x.as_subclass(torch.Tensor) x = x.as_subclass(torch.Tensor)
# permuting the input [batch, channels, x, y, ...] # permuting the input [batch, channels, x, y, ...]
@@ -196,7 +204,9 @@ class FNO(KernelNeuralOperator):
DOI: `arXiv preprint arXiv:2010.08895. DOI: `arXiv preprint arXiv:2010.08895.
<https://arxiv.org/abs/2010.08895>`_ <https://arxiv.org/abs/2010.08895>`_
""" """
def __init__(self,
def __init__(
self,
lifting_net, lifting_net,
projecting_net, projecting_net,
n_modes, n_modes,
@@ -206,7 +216,8 @@ class FNO(KernelNeuralOperator):
inner_size=20, inner_size=20,
n_layers=2, n_layers=2,
func=nn.Tanh, func=nn.Tanh,
layers=None): layers=None,
):
""" """
:param torch.nn.Module lifting_net: The neural network for lifting :param torch.nn.Module lifting_net: The neural network for lifting
the input. the input.
@@ -222,13 +233,14 @@ class FNO(KernelNeuralOperator):
:param list[int] layers: List of layer sizes, defaults to None. :param list[int] layers: List of layer sizes, defaults to None.
""" """
lifting_operator_out = lifting_net( lifting_operator_out = lifting_net(
torch.rand(size=next(lifting_net.parameters()).size())).shape[-1] torch.rand(size=next(lifting_net.parameters()).size())
super().__init__(lifting_operator=lifting_net, ).shape[-1]
super().__init__(
lifting_operator=lifting_net,
projection_operator=projecting_net, projection_operator=projecting_net,
integral_kernels=FourierIntegralKernel( integral_kernels=FourierIntegralKernel(
input_numb_fields=lifting_operator_out, input_numb_fields=lifting_operator_out,
output_numb_fields=next( output_numb_fields=next(projecting_net.parameters()).size(),
projecting_net.parameters()).size(),
n_modes=n_modes, n_modes=n_modes,
dimensions=dimensions, dimensions=dimensions,
padding=padding, padding=padding,
@@ -236,7 +248,9 @@ class FNO(KernelNeuralOperator):
inner_size=inner_size, inner_size=inner_size,
n_layers=n_layers, n_layers=n_layers,
func=func, func=func,
layers=layers)) layers=layers,
),
)
def forward(self, x): def forward(self, x):
""" """