tut10
This commit is contained in:
committed by
Nicola Demo
parent
ada9643c11
commit
98f7694d6f
@@ -1,7 +1,7 @@
|
||||
"""Module Averaging Neural Operator."""
|
||||
|
||||
import torch
|
||||
from torch import nn, concatenate
|
||||
from . import FeedForward
|
||||
from .layers import AVNOBlock
|
||||
from .base_no import KernelNeuralOperator
|
||||
from pina.utils import check_consistency
|
||||
@@ -28,65 +28,61 @@ class AveragingNeuralOperator(KernelNeuralOperator):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_numb_fields,
|
||||
output_numb_fields,
|
||||
lifting_net,
|
||||
projecting_net,
|
||||
field_indices,
|
||||
coordinates_indices,
|
||||
dimension=3,
|
||||
inner_size=100,
|
||||
n_layers=4,
|
||||
func=nn.GELU,
|
||||
):
|
||||
"""
|
||||
:param int input_numb_fields: The number of input components
|
||||
of the model.
|
||||
:param int output_numb_fields: The number of output components
|
||||
of the model.
|
||||
:param int dimension: the dimension of the domain of the functions.
|
||||
:param int inner_size: number of neurons in the hidden layer(s).
|
||||
Defaults to 100.
|
||||
:param int n_layers: number of hidden layers. Default is 4.
|
||||
:param func: the activation function to use. Default to nn.GELU.
|
||||
:param torch.nn.Module lifting_net: The neural network for lifting
|
||||
the input. It must take as input the input field and the coordinates
|
||||
at which the input field is avaluated. The output of the lifting
|
||||
net is chosen as embedding dimension of the problem
|
||||
:param torch.nn.Module projecting_net: The neural network for
|
||||
projecting the output. It must take as input the embedding dimension
|
||||
(output of the ``lifting_net``) plus the dimension
|
||||
of the coordinates.
|
||||
:param list[str] field_indices: the label of the fields
|
||||
in the input tensor.
|
||||
:param list[str] coordinates_indices: the label of the
|
||||
coordinates in the input tensor.
|
||||
:param int n_layers: number of hidden layers. Default is 4.
|
||||
:param torch.nn.Module func: the activation function to use,
|
||||
default to torch.nn.GELU.
|
||||
"""
|
||||
|
||||
# check consistency
|
||||
check_consistency(input_numb_fields, int)
|
||||
check_consistency(output_numb_fields, int)
|
||||
check_consistency(field_indices, str)
|
||||
check_consistency(coordinates_indices, str)
|
||||
check_consistency(dimension, int)
|
||||
check_consistency(inner_size, int)
|
||||
check_consistency(n_layers, int)
|
||||
check_consistency(func, nn.Module, subclass=True)
|
||||
|
||||
# check hidden dimensions match
|
||||
input_lifting_net = next(lifting_net.parameters()).size()[-1]
|
||||
output_lifting_net = lifting_net(
|
||||
torch.rand(size=next(lifting_net.parameters()).size())
|
||||
).shape[-1]
|
||||
projecting_net_input=next(projecting_net.parameters()).size()[-1]
|
||||
|
||||
if len(field_indices)+len(coordinates_indices) != input_lifting_net:
|
||||
raise ValueError('The lifting_net must take as input the '
|
||||
'coordinates vector and the field vector.')
|
||||
|
||||
if output_lifting_net+len(coordinates_indices) != projecting_net_input:
|
||||
raise ValueError('The projecting_net input must be equal to'
|
||||
'the embedding dimension (which is the output) '
|
||||
'of the lifting_net plus the dimension of the '
|
||||
'coordinates, i.e. len(coordinates_indices).')
|
||||
|
||||
# assign
|
||||
self.input_numb_fields = input_numb_fields
|
||||
self.output_numb_fields = output_numb_fields
|
||||
self.dimension = dimension
|
||||
self.coordinates_indices = coordinates_indices
|
||||
self.field_indices = field_indices
|
||||
integral_net = nn.Sequential(
|
||||
*[AVNOBlock(inner_size, func) for _ in range(n_layers)]
|
||||
*[AVNOBlock(output_lifting_net, func) for _ in range(n_layers)]
|
||||
)
|
||||
lifting_net = FeedForward(
|
||||
dimension + input_numb_fields,
|
||||
inner_size,
|
||||
inner_size,
|
||||
n_layers,
|
||||
func,
|
||||
)
|
||||
projection_net = FeedForward(
|
||||
inner_size + dimension,
|
||||
output_numb_fields,
|
||||
inner_size,
|
||||
n_layers,
|
||||
func,
|
||||
)
|
||||
super().__init__(lifting_net, integral_net, projection_net)
|
||||
super().__init__(lifting_net, integral_net, projecting_net)
|
||||
|
||||
def forward(self, x):
|
||||
r"""
|
||||
@@ -106,10 +102,10 @@ class AveragingNeuralOperator(KernelNeuralOperator):
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
points_tmp = x.extract(self.coordinates_indices)
|
||||
features_tmp = x.extract(self.field_indices)
|
||||
new_batch = concatenate((features_tmp, points_tmp), dim=2)
|
||||
new_batch = x.extract(self.field_indices)
|
||||
new_batch = concatenate((new_batch, points_tmp), dim=2)
|
||||
new_batch = self._lifting_operator(new_batch)
|
||||
new_batch = self._integral_kernels(new_batch)
|
||||
new_batch = concatenate((new_batch, points_tmp), dim=2)
|
||||
new_batch = self._projection_operator(new_batch)
|
||||
return new_batch
|
||||
return new_batch
|
||||
Reference in New Issue
Block a user