Dario Coscia
2024-03-14 19:16:48 +01:00
committed by Nicola Demo
parent ada9643c11
commit 98f7694d6f
11 changed files with 1171 additions and 78 deletions


@@ -1,7 +1,7 @@
"""Module Averaging Neural Operator."""
import torch
from torch import nn, concatenate
from . import FeedForward
from .layers import AVNOBlock
from .base_no import KernelNeuralOperator
from pina.utils import check_consistency
@@ -28,65 +28,61 @@ class AveragingNeuralOperator(KernelNeuralOperator):
def __init__(
self,
input_numb_fields,
output_numb_fields,
lifting_net,
projecting_net,
field_indices,
coordinates_indices,
dimension=3,
inner_size=100,
n_layers=4,
func=nn.GELU,
):
"""
:param int input_numb_fields: The number of input components
of the model.
:param int output_numb_fields: The number of output components
of the model.
:param int dimension: the dimension of the domain of the functions.
:param int inner_size: number of neurons in the hidden layer(s).
Defaults to 100.
:param int n_layers: number of hidden layers. Default is 4.
:param func: the activation function to use. Default to nn.GELU.
:param torch.nn.Module lifting_net: The neural network for lifting
the input. It must take as input the input field and the coordinates
at which the input field is avaluated. The output of the lifting
net is chosen as embedding dimension of the problem
:param torch.nn.Module projecting_net: The neural network for
projecting the output. It must take as input the embedding dimension
(output of the ``lifting_net``) plus the dimension
of the coordinates.
:param list[str] field_indices: the label of the fields
in the input tensor.
:param list[str] coordinates_indices: the label of the
coordinates in the input tensor.
:param int n_layers: number of hidden layers. Default is 4.
:param torch.nn.Module func: the activation function to use,
default to torch.nn.GELU.
"""
# check consistency
check_consistency(input_numb_fields, int)
check_consistency(output_numb_fields, int)
check_consistency(field_indices, str)
check_consistency(coordinates_indices, str)
check_consistency(dimension, int)
check_consistency(inner_size, int)
check_consistency(n_layers, int)
check_consistency(func, nn.Module, subclass=True)
# check hidden dimensions match
input_lifting_net = next(lifting_net.parameters()).size()[-1]
output_lifting_net = lifting_net(
torch.rand(size=next(lifting_net.parameters()).size())
).shape[-1]
projecting_net_input = next(projecting_net.parameters()).size()[-1]
if len(field_indices) + len(coordinates_indices) != input_lifting_net:
raise ValueError('The lifting_net must take as input the '
'coordinates vector and the field vector.')
if output_lifting_net + len(coordinates_indices) != projecting_net_input:
raise ValueError('The projecting_net input must be equal to '
'the embedding dimension (which is the output '
'of the lifting_net) plus the dimension of the '
'coordinates, i.e. len(coordinates_indices).')
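# illustrative example of the contract checked above (values are assumed,
# not taken from the library): with coordinates_indices=['x', 'y', 'z'] and
# field_indices=['u'], lifting_net must accept 3 + 1 = 4 input features;
# if it embeds into 64 features, projecting_net must accept 64 + 3 = 67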
# assign
self.input_numb_fields = input_numb_fields
self.output_numb_fields = output_numb_fields
self.dimension = dimension
self.coordinates_indices = coordinates_indices
self.field_indices = field_indices
integral_net = nn.Sequential(
*[AVNOBlock(output_lifting_net, func) for _ in range(n_layers)]
)
super().__init__(lifting_net, integral_net, projecting_net)
def forward(self, x):
r"""
@@ -106,10 +102,10 @@ class AveragingNeuralOperator(KernelNeuralOperator):
:rtype: torch.Tensor
"""
points_tmp = x.extract(self.coordinates_indices)
new_batch = x.extract(self.field_indices)
new_batch = concatenate((new_batch, points_tmp), dim=2)
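# lift the (fields, coordinates) batch to the embedding space, apply the
# stacked averaging (AVNO) kernel layers, then re-append the coordinates
# and project back to the output fields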
new_batch = self._lifting_operator(new_batch)
new_batch = self._integral_kernels(new_batch)
new_batch = concatenate((new_batch, points_tmp), dim=2)
new_batch = self._projection_operator(new_batch)
return new_batch
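
A minimal usage sketch of the constructor introduced in this commit: the import paths, network sizes, and labels below are illustrative assumptions, not part of the commit; only the constructor signature and the shape checks come from the code above.

import torch
from torch import nn
from pina import LabelTensor                      # assumed import path
from pina.model import AveragingNeuralOperator    # assumed import path

coordinates_indices = ["x", "y", "z"]   # 3 spatial coordinates
field_indices = ["u"]                   # 1 input field
embedding = 64                          # output size of the lifting net

# lifting_net: (fields + coordinates) -> embedding dimension
lifting_net = nn.Linear(len(field_indices) + len(coordinates_indices), embedding)
# projecting_net: (embedding + coordinates) -> output fields
projecting_net = nn.Linear(embedding + len(coordinates_indices), 1)

model = AveragingNeuralOperator(
    input_numb_fields=1,
    output_numb_fields=1,
    lifting_net=lifting_net,
    projecting_net=projecting_net,
    field_indices=field_indices,
    coordinates_indices=coordinates_indices,
)

# batch of 10 functions sampled at 100 points; the labels let forward()
# extract fields and coordinates by name (assumed LabelTensor usage)
x = LabelTensor(torch.rand(10, 100, 4), labels=["u", "x", "y", "z"])
out = model(x)  # expected shape: (10, 100, 1)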