🎨 Format Python code with psf/black

This commit is contained in:
ndem0
2024-02-21 10:15:57 +00:00
committed by Nicola Demo
parent e516e779f9
commit c2529d325a
3 changed files with 78 additions and 63 deletions

View File

@@ -1,12 +1,12 @@
__all__ = [ __all__ = [
'FeedForward', "FeedForward",
'ResidualFeedForward', "ResidualFeedForward",
'MultiFeedForward', "MultiFeedForward",
'DeepONet', "DeepONet",
'MIONet', "MIONet",
'FNO', "FNO",
'FourierIntegralKernel', "FourierIntegralKernel",
'KernelNeuralOperator' "KernelNeuralOperator",
] ]
from .feed_forward import FeedForward, ResidualFeedForward from .feed_forward import FeedForward, ResidualFeedForward

View File

@@ -44,6 +44,7 @@ class KernelNeuralOperator(torch.nn.Module):
spaces with applications to PDEs*. Journal of Machine Learning spaces with applications to PDEs*. Journal of Machine Learning
Research, 24(89), 1-97. Research, 24(89), 1-97.
""" """
def __init__(self, lifting_operator, integral_kernels, projection_operator): def __init__(self, lifting_operator, integral_kernels, projection_operator):
""" """
:param torch.nn.Module lifting_operator: The lifting operator :param torch.nn.Module lifting_operator: The lifting operator

View File

@@ -28,17 +28,20 @@ class FourierIntegralKernel(torch.nn.Module):
DOI: `arXiv preprint arXiv:2010.08895. DOI: `arXiv preprint arXiv:2010.08895.
<https://arxiv.org/abs/2010.08895>`_ <https://arxiv.org/abs/2010.08895>`_
""" """
def __init__(self,
input_numb_fields, def __init__(
output_numb_fields, self,
n_modes, input_numb_fields,
dimensions=3, output_numb_fields,
padding=8, n_modes,
padding_type="constant", dimensions=3,
inner_size=20, padding=8,
n_layers=2, padding_type="constant",
func=nn.Tanh, inner_size=20,
layers=None): n_layers=2,
func=nn.Tanh,
layers=None,
):
""" """
:param int input_numb_fields: Number of input fields. :param int input_numb_fields: Number of input fields.
:param int output_numb_fields: Number of output fields. :param int output_numb_fields: Number of output fields.
@@ -69,7 +72,8 @@ class FourierIntegralKernel(torch.nn.Module):
if not isinstance(n_modes, (list, tuple, int)): if not isinstance(n_modes, (list, tuple, int)):
raise ValueError( raise ValueError(
"n_modes must be a int or list or tuple of valid modes." "n_modes must be a int or list or tuple of valid modes."
" More information on the official documentation.") " More information on the official documentation."
)
# assign padding # assign padding
self._padding = padding self._padding = padding
@@ -82,9 +86,7 @@ class FourierIntegralKernel(torch.nn.Module):
elif dimensions == 3: elif dimensions == 3:
fourier_layer = FourierBlock3D fourier_layer = FourierBlock3D
else: else:
raise NotImplementedError( raise NotImplementedError("FNO implemented only for 1D/2D/3D data.")
"FNO implemented only for 1D/2D/3D data."
)
# Here we build the FNO kernels by stacking Fourier Blocks # Here we build the FNO kernels by stacking Fourier Blocks
@@ -96,7 +98,8 @@ class FourierIntegralKernel(torch.nn.Module):
if isinstance(func, list): if isinstance(func, list):
if len(layers) != len(func): if len(layers) != len(func):
raise RuntimeError( raise RuntimeError(
'Uncosistent number of layers and functions.') "Uncosistent number of layers and functions."
)
_functions = func _functions = func
else: else:
_functions = [func for _ in range(len(layers) - 1)] _functions = [func for _ in range(len(layers) - 1)]
@@ -104,10 +107,12 @@ class FourierIntegralKernel(torch.nn.Module):
# 3. Assign modes functions for each FNO layer # 3. Assign modes functions for each FNO layer
if isinstance(n_modes, list): if isinstance(n_modes, list):
if all(isinstance(i, list) if all(isinstance(i, list) for i in n_modes) and len(layers) != len(
for i in n_modes) and len(layers) != len(n_modes): n_modes
):
raise RuntimeError( raise RuntimeError(
"Uncosistent number of layers and functions.") "Uncosistent number of layers and functions."
)
elif all(isinstance(i, int) for i in n_modes): elif all(isinstance(i, int) for i in n_modes):
n_modes = [n_modes] * len(layers) n_modes = [n_modes] * len(layers)
else: else:
@@ -118,10 +123,13 @@ class FourierIntegralKernel(torch.nn.Module):
tmp_layers = [input_numb_fields] + layers + [output_numb_fields] tmp_layers = [input_numb_fields] + layers + [output_numb_fields]
for i in range(len(layers)): for i in range(len(layers)):
_layers.append( _layers.append(
fourier_layer(input_numb_fields=tmp_layers[i], fourier_layer(
output_numb_fields=tmp_layers[i + 1], input_numb_fields=tmp_layers[i],
n_modes=n_modes[i], output_numb_fields=tmp_layers[i + 1],
activation=_functions[i])) n_modes=n_modes[i],
activation=_functions[i],
)
)
self._layers = nn.Sequential(*_layers) self._layers = nn.Sequential(*_layers)
# 5. Padding values for spectral conv # 5. Padding values for spectral conv
@@ -150,10 +158,10 @@ class FourierIntegralKernel(torch.nn.Module):
:return: The output tensor obtained from the kernels convolution. :return: The output tensor obtained from the kernels convolution.
:rtype: torch.Tensor :rtype: torch.Tensor
""" """
if isinstance(x, LabelTensor): #TODO remove when Network is fixed if isinstance(x, LabelTensor): # TODO remove when Network is fixed
warnings.warn( warnings.warn(
'LabelTensor passed as input is not allowed,' "LabelTensor passed as input is not allowed,"
' casting LabelTensor to Torch.Tensor' " casting LabelTensor to Torch.Tensor"
) )
x = x.as_subclass(torch.Tensor) x = x.as_subclass(torch.Tensor)
# permuting the input [batch, channels, x, y, ...] # permuting the input [batch, channels, x, y, ...]
@@ -196,17 +204,20 @@ class FNO(KernelNeuralOperator):
DOI: `arXiv preprint arXiv:2010.08895. DOI: `arXiv preprint arXiv:2010.08895.
<https://arxiv.org/abs/2010.08895>`_ <https://arxiv.org/abs/2010.08895>`_
""" """
def __init__(self,
lifting_net, def __init__(
projecting_net, self,
n_modes, lifting_net,
dimensions=3, projecting_net,
padding=8, n_modes,
padding_type="constant", dimensions=3,
inner_size=20, padding=8,
n_layers=2, padding_type="constant",
func=nn.Tanh, inner_size=20,
layers=None): n_layers=2,
func=nn.Tanh,
layers=None,
):
""" """
:param torch.nn.Module lifting_net: The neural network for lifting :param torch.nn.Module lifting_net: The neural network for lifting
the input. the input.
@@ -222,21 +233,24 @@ class FNO(KernelNeuralOperator):
:param list[int] layers: List of layer sizes, defaults to None. :param list[int] layers: List of layer sizes, defaults to None.
""" """
lifting_operator_out = lifting_net( lifting_operator_out = lifting_net(
torch.rand(size=next(lifting_net.parameters()).size())).shape[-1] torch.rand(size=next(lifting_net.parameters()).size())
super().__init__(lifting_operator=lifting_net, ).shape[-1]
projection_operator=projecting_net, super().__init__(
integral_kernels=FourierIntegralKernel( lifting_operator=lifting_net,
input_numb_fields=lifting_operator_out, projection_operator=projecting_net,
output_numb_fields=next( integral_kernels=FourierIntegralKernel(
projecting_net.parameters()).size(), input_numb_fields=lifting_operator_out,
n_modes=n_modes, output_numb_fields=next(projecting_net.parameters()).size(),
dimensions=dimensions, n_modes=n_modes,
padding=padding, dimensions=dimensions,
padding_type=padding_type, padding=padding,
inner_size=inner_size, padding_type=padding_type,
n_layers=n_layers, inner_size=inner_size,
func=func, n_layers=n_layers,
layers=layers)) func=func,
layers=layers,
),
)
def forward(self, x): def forward(self, x):
""" """