🎨 Format Python code with psf/black

This commit is contained in:
ndem0
2024-03-05 11:31:14 +00:00
committed by Nicola Demo
parent 43f69242ab
commit 3d72205380
2 changed files with 25 additions and 14 deletions

View File

@@ -9,7 +9,7 @@ from pina.utils import check_consistency
class AveragingNeuralOperator(KernelNeuralOperator): class AveragingNeuralOperator(KernelNeuralOperator):
""" """
Implementation of Averaging Neural Operator. Implementation of Averaging Neural Operator.
Averaging Neural Operator is a general architecture for Averaging Neural Operator is a general architecture for
learning Operators. Unlike traditional machine learning methods learning Operators. Unlike traditional machine learning methods
@@ -38,19 +38,19 @@ class AveragingNeuralOperator(KernelNeuralOperator):
func=nn.GELU, func=nn.GELU,
): ):
""" """
:param int input_numb_fields: The number of input components :param int input_numb_fields: The number of input components
of the model. of the model.
:param int output_numb_fields: The number of output components :param int output_numb_fields: The number of output components
of the model. of the model.
:param int dimension: the dimension of the domain of the functions. :param int dimension: the dimension of the domain of the functions.
:param int inner_size: number of neurons in the hidden layer(s). :param int inner_size: number of neurons in the hidden layer(s).
Defaults to 100. Defaults to 100.
:param int n_layers: number of hidden layers. Default is 4. :param int n_layers: number of hidden layers. Default is 4.
:param func: the activation function to use. Default to nn.GELU. :param func: the activation function to use. Default to nn.GELU.
:param list[str] field_indices: the label of the fields :param list[str] field_indices: the label of the fields
in the input tensor. in the input tensor.
:param list[str] coordinates_indices: the label of the :param list[str] coordinates_indices: the label of the
coordinates in the input tensor. coordinates in the input tensor.
""" """
# check consistency # check consistency
@@ -70,11 +70,22 @@ class AveragingNeuralOperator(KernelNeuralOperator):
self.coordinates_indices = coordinates_indices self.coordinates_indices = coordinates_indices
self.field_indices = field_indices self.field_indices = field_indices
integral_net = nn.Sequential( integral_net = nn.Sequential(
*[AVNOBlock(inner_size, func) for _ in range(n_layers)]) *[AVNOBlock(inner_size, func) for _ in range(n_layers)]
lifting_net = FeedForward(dimension + input_numb_fields, inner_size, )
inner_size, n_layers, func) lifting_net = FeedForward(
projection_net = FeedForward(inner_size + dimension, output_numb_fields, dimension + input_numb_fields,
inner_size, n_layers, func) inner_size,
inner_size,
n_layers,
func,
)
projection_net = FeedForward(
inner_size + dimension,
output_numb_fields,
inner_size,
n_layers,
func,
)
super().__init__(lifting_net, integral_net, projection_net) super().__init__(lifting_net, integral_net, projection_net)
def forward(self, x): def forward(self, x):

View File

@@ -27,7 +27,7 @@ class AVNOBlock(nn.Module):
.. seealso:: .. seealso::
**Original reference**: Lanthaler S. Li, Z., Kovachki, **Original reference**: Lanthaler S. Li, Z., Kovachki,
Stuart, A. (2020). *The Nonlocal Neural Operator: Universal Stuart, A. (2020). *The Nonlocal Neural Operator: Universal
Approximation*. Approximation*.
DOI: `arXiv preprint arXiv:2304.13221. DOI: `arXiv preprint arXiv:2304.13221.