Fix Codacy Warnings (#477)

---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2025-03-10 15:38:45 +01:00
committed by Nicola Demo
parent e3790e049a
commit 4177bfbb50
157 changed files with 3473 additions and 3839 deletions

View File

@@ -1,3 +1,7 @@
"""
Module containing the neural network models.
"""
__all__ = [
"FeedForward",
"ResidualFeedForward",

View File

@@ -1,8 +1,8 @@
"""Module Averaging Neural Operator."""
import torch
from torch import nn, cat
from .block import AVNOBlock
from torch import nn
from .block.average_neural_operator_block import AVNOBlock
from .kernel_neural_operator import KernelNeuralOperator
from ..utils import check_consistency
@@ -110,9 +110,9 @@ class AveragingNeuralOperator(KernelNeuralOperator):
"""
points_tmp = x.extract(self.coordinates_indices)
new_batch = x.extract(self.field_indices)
new_batch = cat((new_batch, points_tmp), dim=-1)
new_batch = torch.cat((new_batch, points_tmp), dim=-1)
new_batch = self._lifting_operator(new_batch)
new_batch = self._integral_kernels(new_batch)
new_batch = cat((new_batch, points_tmp), dim=-1)
new_batch = torch.cat((new_batch, points_tmp), dim=-1)
new_batch = self._projection_operator(new_batch)
return new_batch

View File

@@ -1,3 +1,7 @@
"""
Module containing the building blocks for models.
"""
__all__ = [
"ContinuousConvBlock",
"ResidualBlock",

View File

@@ -1,6 +1,7 @@
"""Module for Averaging Neural Operator Layer class."""
from torch import nn, mean
import torch
from torch import nn
from ...utils import check_consistency
@@ -64,4 +65,4 @@ class AVNOBlock(nn.Module):
:return: The output tensor obtained from Average Neural Operator Block.
:rtype: torch.Tensor
"""
return self._func(self._nn(x) + mean(x, dim=1, keepdim=True))
return self._func(self._nn(x) + torch.mean(x, dim=1, keepdim=True))

View File

@@ -75,34 +75,29 @@ class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
"""
super().__init__()
if isinstance(input_numb_field, int):
self._input_numb_field = input_numb_field
else:
if not isinstance(input_numb_field, int):
raise ValueError("input_numb_field must be int.")
self._input_numb_field = input_numb_field
if isinstance(output_numb_field, int):
self._output_numb_field = output_numb_field
else:
if not isinstance(output_numb_field, int):
raise ValueError("input_numb_field must be int.")
self._output_numb_field = output_numb_field
if isinstance(filter_dim, (tuple, list)):
vect = filter_dim
else:
if not isinstance(filter_dim, (tuple, list)):
raise ValueError("filter_dim must be tuple or list.")
vect = filter_dim
vect = torch.tensor(vect)
self.register_buffer("_dim", vect, persistent=False)
if isinstance(stride, dict):
self._stride = Stride(stride)
else:
if not isinstance(stride, dict):
raise ValueError("stride must be dictionary.")
self._stride = Stride(stride)
self._net = model
if isinstance(optimize, bool):
self._optimize = optimize
else:
if not isinstance(optimize, bool):
raise ValueError("optimize must be bool.")
self._optimize = optimize
# choosing how to initialize based on optimization
if self._optimize:
@@ -119,13 +114,18 @@ class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
if no_overlap:
raise NotImplementedError
self.transpose = self.transpose_no_overlap
else:
self.transpose = self.transpose_overlap
self.transpose = self.transpose_overlap
class DefaultKernel(torch.nn.Module):
"""
TODO
"""
def __init__(self, input_dim, output_dim):
"""
TODO
"""
super().__init__()
assert isinstance(input_dim, int)
assert isinstance(output_dim, int)
@@ -138,44 +138,66 @@ class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
)
def forward(self, x):
"""
TODO
"""
return self._model(x)
@property
def net(self):
"""
TODO
"""
return self._net
@property
def stride(self):
"""
TODO
"""
return self._stride
@property
def filter_dim(self):
"""
TODO
"""
return self._dim
@property
def input_numb_field(self):
"""
TODO
"""
return self._input_numb_field
@property
def output_numb_field(self):
"""
TODO
"""
return self._output_numb_field
@property
@abstractmethod
def forward(self, X):
pass
"""
TODO
"""
@property
@abstractmethod
def transpose_overlap(self, X):
pass
"""
TODO
"""
@property
@abstractmethod
def transpose_no_overlap(self, X):
pass
"""
TODO
"""
@property
@abstractmethod
def _initialize_convolution(self, X, type):
pass
def _initialize_convolution(self, X, type_):
"""
TODO
"""

View File

@@ -1,9 +1,9 @@
"""Module for Continuous Convolution class"""
import torch
from .convolution import BaseContinuousConv
from .utils_convolution import check_point, map_points_
from .integral import Integral
import torch
class ContinuousConvBlock(BaseContinuousConv):
@@ -27,8 +27,9 @@ class ContinuousConvBlock(BaseContinuousConv):
.. seealso::
**Original reference**: Coscia, D., Meneghetti, L., Demo, N. et al.
*A continuous convolutional trainable filter for modelling unstructured data*.
Comput Mech 72, 253265 (2023). DOI `<https://doi.org/10.1007/s00466-023-02291-1>`_
*A continuous convolutional trainable filter for modelling
unstructured data*. Comput Mech 72, 253265 (2023).
DOI `<https://doi.org/10.1007/s00466-023-02291-1>`_
"""
@@ -45,7 +46,8 @@ class ContinuousConvBlock(BaseContinuousConv):
"""
:param input_numb_field: Number of fields :math:`N_{in}` in the input.
:type input_numb_field: int
:param output_numb_field: Number of fields :math:`N_{out}` in the output.
:param output_numb_field: Number of fields :math:`N_{out}` in the
output.
:type output_numb_field: int
:param filter_dim: Dimension of the filter.
:type filter_dim: tuple(int) | list(int)
@@ -134,6 +136,11 @@ class ContinuousConvBlock(BaseContinuousConv):
# stride for continuous convolution overridden
self._stride = self._stride._stride_discrete
# Define variables
self._index = None
self._grid = None
self._grid_transpose = None
def _spawn_networks(self, model):
"""
Private method to create a collection of kernels
@@ -152,7 +159,7 @@ class ContinuousConvBlock(BaseContinuousConv):
else:
if not isinstance(model, object):
raise ValueError(
"Expected a python class inheriting" " from torch.nn.Module"
"Expected a python class inheriting from torch.nn.Module"
)
for _ in range(self._input_numb_field * self._output_numb_field):
@@ -271,7 +278,7 @@ class ContinuousConvBlock(BaseContinuousConv):
# save on tmp
self._grid_transpose = tmp
def _make_grid(self, X, type):
def _make_grid(self, X, type_):
"""
Private method to create convolution grid.
@@ -283,14 +290,15 @@ class ContinuousConvBlock(BaseContinuousConv):
"""
# choose the type of convolution
if type == "forward":
return self._make_grid_forward(X)
elif type == "inverse":
if type_ == "forward":
self._make_grid_forward(X)
return
if type_ == "inverse":
self._make_grid_transpose(X)
else:
raise TypeError
return
raise TypeError
def _initialize_convolution(self, X, type="forward"):
def _initialize_convolution(self, X, type_="forward"):
"""
Private method to intialize the convolution.
The convolution is initialized by setting a grid and
@@ -304,7 +312,7 @@ class ContinuousConvBlock(BaseContinuousConv):
"""
# variable for the convolution
self._make_grid(X, type)
self._make_grid(X, type_)
# calculate the index
self._find_index(X)
@@ -321,7 +329,7 @@ class ContinuousConvBlock(BaseContinuousConv):
# initialize convolution
if self.training: # we choose what to do based on optimization
self._choose_initialization(X, type="forward")
self._choose_initialization(X, type_="forward")
else: # we always initialize on testing
self._initialize_convolution(X, "forward")
@@ -383,12 +391,14 @@ class ContinuousConvBlock(BaseContinuousConv):
:type integral: torch.tensor
:param X: Input data. Expect tensor of shape
:math:`[B, N_{in}, M, D]` where :math:`B` is the batch_size,
:math`N_{in}`is the number of input fields, :math:`M` the number of points
:math`N_{in}`is the number of input fields, :math:`M` the number of
points
in the mesh, :math:`D` the dimension of the problem.
:type X: torch.Tensor
:return: Feed forward transpose convolution. Tensor of shape
:math:`[B, N_{out}, M, D]` where :math:`B` is the batch_size,
:math`N_{out}`is the number of input fields, :math:`M` the number of points
:math`N_{out}`is the number of input fields, :math:`M` the number of
points
in the mesh, :math:`D` the dimension of the problem.
:rtype: torch.Tensor
@@ -399,7 +409,7 @@ class ContinuousConvBlock(BaseContinuousConv):
# initialize convolution
if self.training: # we choose what to do based on optimization
self._choose_initialization(X, type="inverse")
self._choose_initialization(X, type_="inverse")
else: # we always initialize on testing
self._initialize_convolution(X, "inverse")
@@ -466,12 +476,14 @@ class ContinuousConvBlock(BaseContinuousConv):
:type integral: torch.tensor
:param X: Input data. Expect tensor of shape
:math:`[B, N_{in}, M, D]` where :math:`B` is the batch_size,
:math`N_{in}`is the number of input fields, :math:`M` the number of points
:math`N_{in}`is the number of input fields, :math:`M` the number of
points
in the mesh, :math:`D` the dimension of the problem.
:type X: torch.Tensor
:return: Feed forward transpose convolution. Tensor of shape
:math:`[B, N_{out}, M, D]` where :math:`B` is the batch_size,
:math`N_{out}`is the number of input fields, :math:`M` the number of points
:math`N_{out}`is the number of input fields, :math:`M` the number of
points
in the mesh, :math:`D` the dimension of the problem.
:rtype: torch.Tensor
@@ -481,7 +493,7 @@ class ContinuousConvBlock(BaseContinuousConv):
# initialize convolution
if self.training: # we choose what to do based on optimization
self._choose_initialization(X, type="inverse")
self._choose_initialization(X, type_="inverse")
else: # we always initialize on testing
self._initialize_convolution(X, "inverse")
@@ -491,7 +503,7 @@ class ContinuousConvBlock(BaseContinuousConv):
conv_transposed = self._grid_transpose.clone().detach()
# list to iterate for calculating nn output
tmp = [i for i in range(self._output_numb_field)]
tmp = list(range(self._output_numb_field))
iterate_conv = [
item for item in tmp for _ in range(self._input_numb_field)
]

View File

@@ -2,7 +2,6 @@
import torch
from pina.utils import check_consistency
from typing import Union, Sequence
class PeriodicBoundaryEmbedding(torch.nn.Module):
@@ -18,8 +17,9 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
u(\mathbf{x}) = u(\mathbf{x} + n \mathbf{L})\;\;
\forall n\in\mathbb{N}.
The :meth:`PeriodicBoundaryEmbedding` augments the input such that the periodic conditons
is guarantee. The input is augmented by the following formula:
The :meth:`PeriodicBoundaryEmbedding` augments the input such that the
periodic conditonsis guarantee. The input is augmented by the following
formula:
.. math::
\mathbf{x} \rightarrow \tilde{\mathbf{x}} = \left[1,
@@ -135,13 +135,13 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
if isinstance(indeces[0], str):
try:
return x.extract(indeces)
except AttributeError:
except AttributeError as e:
raise RuntimeError(
"Not possible to extract input variables from tensor."
" Ensure that the passed tensor is a LabelTensor or"
" pass list of integers to extract variables. For"
" more information refer to warning in the documentation."
)
) from e
elif isinstance(indeces[0], int):
return x[..., indeces]
else:
@@ -159,11 +159,14 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
class FourierFeatureEmbedding(torch.nn.Module):
"""
Fourier Feature Embedding class for encoding input features
using random Fourier features.
"""
def __init__(self, input_dimension, output_dimension, sigma):
r"""
Fourier Feature Embedding class for encoding input features
using random Fourier features.This class applies a Fourier
transformation to the input features,
This class applies a Fourier transformation to the input features,
which can help in learning high-frequency variations in data.
If multiple sigma are provided, the class
supports multiscale feature embedding, creating embeddings for

View File

@@ -1,8 +1,12 @@
"""
Module for Fourier Block implementation.
"""
import torch
import torch.nn as nn
from torch import nn
from ...utils import check_consistency
from . import (
from .spectral import (
SpectralConvBlock1D,
SpectralConvBlock2D,
SpectralConvBlock3D,
@@ -17,9 +21,9 @@ class FourierBlock1D(nn.Module):
.. seealso::
**Original reference**: Li, Z., Kovachki, N., Azizzadenesheli, K., Liu, B.,
Bhattacharya, K., Stuart, A., & Anandkumar, A. (2020). *Fourier neural operator for
parametric partial differential equations*.
**Original reference**: Li, Z., Kovachki, N., Azizzadenesheli, K.,
Liu, B., Bhattacharya, K., Stuart, A., & Anandkumar, A. (2020). *Fourier
neural operator for parametric partial differential equations*.
DOI: `arXiv preprint arXiv:2010.08895.
<https://arxiv.org/abs/2010.08895>`_
@@ -32,24 +36,26 @@ class FourierBlock1D(nn.Module):
n_modes,
activation=torch.nn.Tanh,
):
super().__init__()
"""
PINA implementation of Fourier block one dimension. The module computes
the spectral convolution of the input with a linear kernel in the
fourier space, and then it maps the input back to the physical
space. The output is then added to a Linear tranformation of the
input in the physical space. Finally an activation function is
applied to the output.
applied to the output.
The block expects an input of size ``[batch, input_numb_fields, N]``
and returns an output of size ``[batch, output_numb_fields, N]``.
:param int input_numb_fields: The number of channels for the input.
:param int output_numb_fields: The number of channels for the output.
:param list | tuple n_modes: Number of modes to select for each dimension.
It must be at most equal to the ``floor(N/2)+1``.
:param list | tuple n_modes: Number of modes to select for each
dimension. It must be at most equal to the ``floor(N/2)+1``.
:param torch.nn.Module activation: The activation function.
"""
super().__init__()
# check type consistency
check_consistency(activation(), nn.Module)
@@ -109,13 +115,15 @@ class FourierBlock2D(nn.Module):
input in the physical space. Finally an activation function is
applied to the output.
The block expects an input of size ``[batch, input_numb_fields, Nx, Ny]``
and returns an output of size ``[batch, output_numb_fields, Nx, Ny]``.
The block expects an input of size
``[batch, input_numb_fields, Nx, Ny]`` and returns an output of size
``[batch, output_numb_fields, Nx, Ny]``.
:param int input_numb_fields: The number of channels for the input.
:param int output_numb_fields: The number of channels for the output.
:param list | tuple n_modes: Number of modes to select for each dimension.
It must be at most equal to the ``floor(Nx/2)+1`` and ``floor(Ny/2)+1``.
:param list | tuple n_modes: Number of modes to select for each
dimension. It must be at most equal to the ``floor(Nx/2)+1``
and ``floor(Ny/2)+1``.
:param torch.nn.Module activation: The activation function.
"""
super().__init__()
@@ -172,21 +180,22 @@ class FourierBlock3D(nn.Module):
activation=torch.nn.Tanh,
):
"""
PINA implementation of Fourier block three dimensions. The module computes
the spectral convolution of the input with a linear kernel in the
fourier space, and then it maps the input back to the physical
PINA implementation of Fourier block three dimensions. The module
computes the spectral convolution of the input with a linear kernel in
the fourier space, and then it maps the input back to the physical
space. The output is then added to a Linear tranformation of the
input in the physical space. Finally an activation function is
applied to the output.
The block expects an input of size ``[batch, input_numb_fields, Nx, Ny, Nz]``
and returns an output of size ``[batch, output_numb_fields, Nx, Ny, Nz]``.
The block expects an input of size
``[batch, input_numb_fields, Nx, Ny, Nz]`` and returns an output of size
``[batch, output_numb_fields, Nx, Ny, Nz]``.
:param int input_numb_fields: The number of channels for the input.
:param int output_numb_fields: The number of channels for the output.
:param list | tuple n_modes: Number of modes to select for each dimension.
It must be at most equal to the ``floor(Nx/2)+1``, ``floor(Ny/2)+1``
and ``floor(Nz/2)+1``.
:param list | tuple n_modes: Number of modes to select for each
dimension. It must be at most equal to the ``floor(Nx/2)+1``,
``floor(Ny/2)+1`` and ``floor(Nz/2)+1``.
:param torch.nn.Module activation: The activation function.
"""
super().__init__()

View File

@@ -1,10 +1,14 @@
"""
Module containing the Graph Integral Layer class.
"""
import torch
from torch_geometric.nn import MessagePassing
class GNOBlock(MessagePassing):
"""
TODO: Add documentation
Graph Neural Operator (GNO) Block using PyG MessagePassing.
"""
def __init__(
@@ -18,21 +22,21 @@ class GNOBlock(MessagePassing):
external_func=None,
):
"""
Initialize the Graph Integral Layer, inheriting from the MessagePassing class of PyTorch Geometric.
Initialize the GNOBlock.
:param width: The width of the hidden representation of the nodes features
:type width: int
:param edges_features: The number of edge features.
:type edges_features: int
:param n_layers: The number of layers in the Feed Forward Neural Network used to compute the representation of the edges features.
:type n_layers: int
:param width: Hidden dimension of node features.
:param edges_features: Number of edge features.
:param n_layers: Number of layers in edge transformation MLP.
"""
from pina.model import FeedForward
super(GNOBlock, self).__init__(aggr="mean")
from ...model.feed_forward import FeedForward
super().__init__(aggr="mean") # Uses PyG's default aggregation
self.width = width
if layers is None and inner_size is None:
inner_size = width
self.dense = FeedForward(
input_dimensions=edges_features,
output_dimensions=width**2,
@@ -41,48 +45,50 @@ class GNOBlock(MessagePassing):
inner_size=inner_size,
func=internal_func,
)
self.W = torch.nn.Linear(width, width)
self.func = external_func()
def message(self, x_j, edge_attr):
def message_and_aggregate(self, edge_index, x, edge_attr):
"""
This function computes the message passed between the nodes of the graph. Overwrite the default message function defined in the MessagePassing class.
Combines message and aggregation.
:param x_j: The node features of the neighboring.
:type x_j: torch.Tensor
:param edge_attr: The edge features.
:type edge_attr: torch.Tensor
:return: The message passed between the nodes of the graph.
:rtype: torch.Tensor
:param edge_index: COO format edge indices.
:param x: Node feature matrix [num_nodes, width].
:param edge_attr: Edge features [num_edges, edge_dim].
:return: Aggregated messages.
"""
x = self.dense(edge_attr).view(-1, self.width, self.width)
return torch.einsum("bij,bj->bi", x, x_j)
# Edge features are transformed into a matrix of shape
# [num_edges, width, width]
x_ = self.dense(edge_attr).view(-1, self.width, self.width)
# Messages are computed as the product of the edge features
messages = torch.einsum("bij,bj->bi", x_, x[edge_index[0]])
# Aggregation is performed using the mean (set in the constructor)
return self.aggregate(messages, edge_index[1])
def edge_update(self, edge_attr):
"""
Updates edge features.
"""
return edge_attr
def update(self, aggr_out, x):
"""
This function updates the node features of the graph. Overwrite the default update function defined in the MessagePassing class.
Updates node features.
:param aggr_out: The aggregated messages.
:type aggr_out: torch.Tensor
:param x: The node features.
:type x: torch.Tensor
:return: The updated node features.
:rtype: torch.Tensor
:param aggr_out: Aggregated messages.
:param x: Node feature matrix.
:return: Updated node features.
"""
aggr_out = aggr_out + self.W(x)
return aggr_out
return aggr_out + self.W(x)
def forward(self, x, edge_index, edge_attr):
"""
The forward pass of the Graph Integral Layer.
Forward pass of the GNOBlock.
:param x: Node features.
:type x: torch.Tensor
:param edge_index: Edge index.
:type edge_index: torch.Tensor
:param edge_index: Edge indices.
:param edge_attr: Edge features.
:type edge_attr: torch.Tensor
:return: Output of a single iteration over the Graph Integral Layer.
:rtype: torch.Tensor
:return: Updated node features.
"""
return self.func(self.propagate(edge_index, x=x, edge_attr=edge_attr))

View File

@@ -1,10 +1,18 @@
"""
Module for performing integral for continuous convolution
"""
import torch
class Integral(object):
class Integral:
"""
Integral class for continous convolution
"""
def __init__(self, param):
"""Integral class for continous convolution
"""
Initialize the integral class
:param param: type of continuous convolution
:type param: string

View File

@@ -2,8 +2,7 @@
import torch
from pina.utils import check_consistency
import pina.model as pm # avoid circular import
from ...utils import check_consistency
class LowRankBlock(torch.nn.Module):
@@ -78,9 +77,10 @@ class LowRankBlock(torch.nn.Module):
basis function network.
"""
super().__init__()
from ..feed_forward import FeedForward
# Assignment (check consistency inside FeedForward)
self._basis = pm.FeedForward(
self._basis = FeedForward(
input_dimensions=input_dimensions,
output_dimensions=2 * rank * embedding_dimenion,
inner_size=inner_size,

View File

@@ -1,10 +1,6 @@
"""Module for Base Continuous Convolution class."""
from abc import ABCMeta, abstractmethod
import torch
from .stride import Stride
from .utils_convolution import optimizing
import warnings
class PODBlock(torch.nn.Module):
@@ -15,7 +11,8 @@ class PODBlock(torch.nn.Module):
The layer is not trainable.
.. note::
All the POD modes are stored in memory, avoiding to recompute them when the rank changes but increasing the memory usage.
All the POD modes are stored in memory, avoiding to recompute them when
the rank changes but increasing the memory usage.
"""
def __init__(self, rank, scale_coefficients=True):
@@ -51,7 +48,8 @@ class PODBlock(torch.nn.Module):
@property
def basis(self):
"""
The POD basis. It is a matrix whose columns are the first `self.rank` POD modes.
The POD basis. It is a matrix whose columns are the first `self.rank`
POD modes.
:rtype: torch.Tensor
"""
@@ -69,7 +67,7 @@ class PODBlock(torch.nn.Module):
:rtype: dict
"""
if self._scaler is None:
return
return None
return {
"mean": self._scaler["mean"][: self.rank],
@@ -115,7 +113,8 @@ class PODBlock(torch.nn.Module):
def _fit_pod(self, X, randomized):
"""
Private method that computes the POD basis of the given tensor and stores it in the private member `_basis`.
Private method that computes the POD basis of the given tensor and
stores it in the private member `_basis`.
:param torch.Tensor X: The tensor to be reduced.
"""

View File

@@ -1,5 +1,9 @@
"""
TODO: Add title.
"""
import torch
import torch.nn as nn
from torch import nn
from ...utils import check_consistency
@@ -35,7 +39,8 @@ class ResidualBlock(nn.Module):
(first block).
:param bool spectral_norm: Apply spectral normalization to feedforward
layers, defaults to False.
:param torch.nn.Module activation: Cctivation function after first block.
:param torch.nn.Module activation: Cctivation function after first
block.
"""
super().__init__()
@@ -81,19 +86,17 @@ class ResidualBlock(nn.Module):
return nn.utils.spectral_norm(x) if self._spectral_norm else x
import torch
import torch.nn as nn
class EnhancedLinear(torch.nn.Module):
"""
A wrapper class for enhancing a linear layer with activation and/or dropout.
:param layer: The linear layer to be enhanced.
:type layer: torch.nn.Module
:param activation: The activation function to be applied after the linear layer.
:param activation: The activation function to be applied after the linear
layer.
:type activation: torch.nn.Module
:param dropout: The dropout probability to be applied after the activation (if provided).
:param dropout: The dropout probability to be applied after the activation
(if provided).
:type dropout: float
:Example:
@@ -110,9 +113,11 @@ class EnhancedLinear(torch.nn.Module):
:param layer: The linear layer to be enhanced.
:type layer: torch.nn.Module
:param activation: The activation function to be applied after the linear layer.
:param activation: The activation function to be applied after the
linear layer.
:type activation: torch.nn.Module
:param dropout: The dropout probability to be applied after the activation (if provided).
:param dropout: The dropout probability to be applied after the
activation (if provided).
:type dropout: float
"""
super().__init__()

View File

@@ -1,7 +1,10 @@
"""
TODO: Add title.
"""
import torch
import torch.nn as nn
from torch import nn
from ...utils import check_consistency
import warnings
######## 1D Spectral Convolution ###########
@@ -13,7 +16,8 @@ class SpectralConvBlock1D(nn.Module):
def __init__(self, input_numb_fields, output_numb_fields, n_modes):
"""
The module computes the spectral convolution of the input with a linear kernel in the
The module computes the spectral convolution of the input with a linear
kernel in the
fourier space, and then it maps the input back to the physical
space.
@@ -106,17 +110,20 @@ class SpectralConvBlock2D(nn.Module):
def __init__(self, input_numb_fields, output_numb_fields, n_modes):
"""
The module computes the spectral convolution of the input with a linear kernel in the
The module computes the spectral convolution of the input with a linear
kernel in the
fourier space, and then it maps the input back to the physical
space.
The block expects an input of size ``[batch, input_numb_fields, Nx, Ny]``
The block expects an input of size
``[batch, input_numb_fields, Nx, Ny]``
and returns an output of size ``[batch, output_numb_fields, Nx, Ny]``.
:param int input_numb_fields: The number of channels for the input.
:param int output_numb_fields: The number of channels for the output.
:param list | tuple n_modes: Number of modes to select for each dimension.
It must be at most equal to the ``floor(Nx/2)+1`` and ``floor(Ny/2)+1``.
:param list | tuple n_modes: Number of modes to select for each
dimension. It must be at most equal to the ``floor(Nx/2)+1`` and
``floor(Ny/2)+1``.
"""
super().__init__()
@@ -234,18 +241,21 @@ class SpectralConvBlock3D(nn.Module):
def __init__(self, input_numb_fields, output_numb_fields, n_modes):
"""
The module computes the spectral convolution of the input with a linear kernel in the
The module computes the spectral convolution of the input with a
linear kernel in the
fourier space, and then it maps the input back to the physical
space.
The block expects an input of size ``[batch, input_numb_fields, Nx, Ny, Nz]``
and returns an output of size ``[batch, output_numb_fields, Nx, Ny, Nz]``.
The block expects an input of size
``[batch, input_numb_fields, Nx, Ny, Nz]``
and returns an output of size
``[batch, output_numb_fields, Nx, Ny, Nz]``.
:param int input_numb_fields: The number of channels for the input.
:param int output_numb_fields: The number of channels for the output.
:param list | tuple n_modes: Number of modes to select for each dimension.
It must be at most equal to the ``floor(Nx/2)+1``, ``floor(Ny/2)+1``
and ``floor(Nz/2)+1``.
:param list | tuple n_modes: Number of modes to select for each
dimension. It must be at most equal to the ``floor(Nx/2)+1``,
``floor(Ny/2)+1`` and ``floor(Nz/2)+1``.
"""
super().__init__()
@@ -347,7 +357,8 @@ class SpectralConvBlock3D(nn.Module):
``[batch, input_numb_fields, x, y, z]``.
:type x: torch.Tensor
:return: The output tensor obtained from the
spectral convolution of size ``[batch, output_numb_fields, x, y, z]``.
spectral convolution of size
``[batch, output_numb_fields, x, y, z]``.
:rtype: torch.Tensor
"""

View File

@@ -1,18 +1,25 @@
"""
TODO: Add description
"""
import torch
class Stride(object):
class Stride:
"""
TODO
"""
def __init__(self, dict):
def __init__(self, dict_):
"""Stride class for continous convolution
:param param: type of continuous convolution
:type param: string
"""
self._dict_stride = dict
self._dict_stride = dict_
self._stride_continuous = None
self._stride_discrete = self._create_stride_discrete(dict)
self._stride_discrete = self._create_stride_discrete(dict_)
def _create_stride_discrete(self, my_dict):
"""Creating the list for applying the filter
@@ -46,13 +53,13 @@ class Stride(object):
# checking
if not all([len(s) == len(domain) for s in my_dict.values()]):
if not all(len(s) == len(domain) for s in my_dict.values()):
raise IndexError("values in the dict must have all same length")
if not all(v >= 0 for v in domain):
raise ValueError("domain values must be greater than 0")
if not all(v == 1 or v == -1 or v == 0 for v in direction):
if not all(v in (0, -1, 1) for v in direction):
raise ValueError("direction must be either equal to 1, -1 or 0")
seq_jumps = [i for i, e in enumerate(jumps) if e == 0]

View File

@@ -1,7 +1,14 @@
"""
TODO
"""
import torch
def check_point(x, current_stride, dim):
"""
TODO
"""
max_stride = current_stride + dim
indeces = torch.logical_and(
x[..., :-1] < max_stride, x[..., :-1] >= current_stride
@@ -33,16 +40,18 @@ def optimizing(f):
def wrapper(*args, **kwargs):
if kwargs["type"] == "forward":
if kwargs["type_"] == "forward":
if not wrapper.has_run_inverse:
wrapper.has_run_inverse = True
return f(*args, **kwargs)
if kwargs["type"] == "inverse":
if kwargs["type_"] == "inverse":
if not wrapper.has_run:
wrapper.has_run = True
return f(*args, **kwargs)
return f(*args, **kwargs)
wrapper.has_run_inverse = False
wrapper.has_run = False

View File

@@ -1,9 +1,9 @@
"""Module for DeepONet model"""
import torch
import torch.nn as nn
from ..utils import check_consistency, is_function
from functools import partial
import torch
from torch import nn
from ..utils import check_consistency, is_function
class MIONet(torch.nn.Module):
@@ -12,8 +12,9 @@ class MIONet(torch.nn.Module):
MIONet is a general architecture for learning Operators defined
on the tensor product of Banach spaces. Unlike traditional machine
learning methods MIONet is designed to map entire functions to other functions.
It can be trained both with Physics Informed or Supervised learning strategies.
learning methods MIONet is designed to map entire functions to other
functions. It can be trained both with Physics Informed or Supervised
learning strategies.
.. seealso::
@@ -37,37 +38,45 @@ class MIONet(torch.nn.Module):
:param dict networks: The neural networks to use as
models. The ``dict`` takes as key a neural network, and
as value the list of indeces to extract from the input variable
in the forward pass of the neural network. If a list of ``int`` is passed,
the corresponding columns of the inner most entries are extracted.
If a list of ``str`` is passed the variables of the corresponding :py:obj:`pina.label_tensor.LabelTensor`
are extracted. The ``torch.nn.Module`` model has to take as input a
in the forward pass of the neural network. If a list of ``int``
is passed, the corresponding columns of the inner most entries are
extracted.
If a list of ``str`` is passed the variables of the corresponding
:py:obj:`pina.label_tensor.LabelTensor`are extracted. The
``torch.nn.Module`` model has to take as input a
:py:obj:`pina.label_tensor.LabelTensor` or :class:`torch.Tensor`.
Default implementation consist of different branch nets and one trunk nets.
Default implementation consist of different branch nets and one
trunk nets.
:param str or Callable aggregator: Aggregator to be used to aggregate
partial results from the modules in `nets`. Partial results are
aggregated component-wise. Available aggregators include
sum: ``+``, product: ``*``, mean: ``mean``, min: ``min``, max: ``max``.
sum: ``+``, product: ``*``, mean: ``mean``, min: ``min``, max:
``max``.
:param str or Callable reduction: Reduction to be used to reduce
the aggregated result of the modules in `nets` to the desired output
dimension. Available reductions include
sum: ``+``, product: ``*``, mean: ``mean``, min: ``min``, max: ``max``.
:param bool or Callable scale: Scaling the final output before returning the
forward pass, default ``True``.
sum: ``+``, product: ``*``, mean: ``mean``, min: ``min``, max:
``max``.
:param bool or Callable scale: Scaling the final output before returning
the forward pass, default ``True``.
:param bool or Callable translation: Translating the final output before
returning the forward pass, default ``True``.
.. warning::
In the forward pass we do not check if the input is instance of
:py:obj:`pina.label_tensor.LabelTensor` or :class:`torch.Tensor`. A general rule is
that for a :py:obj:`pina.label_tensor.LabelTensor` input both list of integers and
list of strings can be passed for ``input_indeces_branch_net``
and ``input_indeces_trunk_net``. Differently, for a :class:`torch.Tensor`
only a list of integers can be passed for ``input_indeces_branch_net``
and ``input_indeces_trunk_net``.
:py:obj:`pina.label_tensor.LabelTensor` or :class:`torch.Tensor`.
A general rule is that for a :py:obj:`pina.label_tensor.LabelTensor`
input both list of integers and list of strings can be passed for
``input_indeces_branch_net``and ``input_indeces_trunk_net``.
Differently, for a :class:`torch.Tensor` only a list of integers
can be passed for ``input_indeces_branch_net``and
``input_indeces_trunk_net``.
:Example:
>>> branch_net1 = FeedForward(input_dimensons=1, output_dimensions=10)
>>> branch_net2 = FeedForward(input_dimensons=2, output_dimensions=10)
>>> branch_net1 = FeedForward(input_dimensons=1,
... output_dimensions=10)
>>> branch_net2 = FeedForward(input_dimensons=2,
... output_dimensions=10)
>>> trunk_net = FeedForward(input_dimensons=1, output_dimensions=10)
>>> networks = {branch_net1 : ['x'],
branch_net2 : ['x', 'y'],
@@ -125,7 +134,7 @@ class MIONet(torch.nn.Module):
if not all(map(lambda x: x == shapes[0], shapes)):
raise ValueError(
"The passed networks have not the same " "output dimension."
"The passed networks have not the same output dimension."
)
# assign trunk and branch net with their input indeces
@@ -163,7 +172,7 @@ class MIONet(torch.nn.Module):
}
def _init_aggregator(self, aggregator):
aggregator_funcs = DeepONet._symbol_functions(dim=2)
aggregator_funcs = self._symbol_functions(dim=2)
if aggregator in aggregator_funcs:
aggregator_func = aggregator_funcs[aggregator]
elif isinstance(aggregator, nn.Module) or is_function(aggregator):
@@ -175,7 +184,7 @@ class MIONet(torch.nn.Module):
self._aggregator_type = aggregator
def _init_reduction(self, reduction):
reduction_funcs = DeepONet._symbol_functions(dim=-1)
reduction_funcs = self._symbol_functions(dim=-1)
if reduction in reduction_funcs:
reduction_func = reduction_funcs[reduction]
elif isinstance(reduction, nn.Module) or is_function(reduction):
@@ -190,13 +199,13 @@ class MIONet(torch.nn.Module):
if isinstance(indeces[0], str):
try:
return x.extract(indeces)
except AttributeError:
except AttributeError as e:
raise RuntimeError(
"Not possible to extract input variables from tensor."
" Ensure that the passed tensor is a LabelTensor or"
" pass list of integers to extract variables. For"
" more information refer to warning in the documentation."
)
) from e
elif isinstance(indeces[0], int):
return x[..., indeces]
else:
@@ -209,7 +218,8 @@ class MIONet(torch.nn.Module):
"""
Defines the computation performed at every call.
:param LabelTensor or torch.Tensor x: The input tensor for the forward call.
:param LabelTensor or torch.Tensor x: The input tensor for the forward
call.
:return: The output computed by the DeepONet model.
:rtype: LabelTensor or torch.Tensor
"""
@@ -225,7 +235,7 @@ class MIONet(torch.nn.Module):
# reduce
output_ = self._reduction(aggregated)
if self._reduction_type in DeepONet._symbol_functions(dim=-1):
if self._reduction_type in self._symbol_functions(dim=-1):
output_ = output_.reshape(-1, 1)
# scale and translate
@@ -309,47 +319,55 @@ class DeepONet(MIONet):
):
"""
:param torch.nn.Module branch_net: The neural network to use as branch
model. It has to take as input a :py:obj:`pina.label_tensor.LabelTensor`
or :class:`torch.Tensor`. The number of dimensions of the output has
to be the same of the ``trunk_net``.
model. It has to take as input a
:py:obj:`pina.label_tensor.LabelTensor` or :class:`torch.Tensor`.
The number of dimensions of the output has to be the same of the
``trunk_net``.
:param torch.nn.Module trunk_net: The neural network to use as trunk
model. It has to take as input a :py:obj:`pina.label_tensor.LabelTensor`
or :class:`torch.Tensor`. The number of dimensions of the output
has to be the same of the ``branch_net``.
model. It has to take as input a
:py:obj:`pina.label_tensor.LabelTensor` or :class:`torch.Tensor`.
The number of dimensions of the output has to be the same of the
``branch_net``.
:param list(int) or list(str) input_indeces_branch_net: List of indeces
to extract from the input variable in the forward pass for the
branch net. If a list of ``int`` is passed, the corresponding columns
of the inner most entries are extracted. If a list of ``str`` is passed
the variables of the corresponding :py:obj:`pina.label_tensor.LabelTensor` are extracted.
branch net. If a list of ``int`` is passed, the corresponding
columns of the inner most entries are extracted. If a list of
``str`` is passed the variables of the corresponding
:py:obj:`pina.label_tensor.LabelTensor` are extracted.
:param list(int) or list(str) input_indeces_trunk_net: List of indeces
to extract from the input variable in the forward pass for the
trunk net. If a list of ``int`` is passed, the corresponding columns
of the inner most entries are extracted. If a list of ``str`` is passed
the variables of the corresponding :py:obj:`pina.label_tensor.LabelTensor` are extracted.
of the inner most entries are extracted. If a list of ``str`` is
passed the variables of the corresponding
:py:obj:`pina.label_tensor.LabelTensor` are extracted.
:param str or Callable aggregator: Aggregator to be used to aggregate
partial results from the modules in `nets`. Partial results are
aggregated component-wise. Available aggregators include
sum: ``+``, product: ``*``, mean: ``mean``, min: ``min``, max: ``max``.
sum: ``+``, product: ``*``, mean: ``mean``, min: ``min``,
max: ``max``.
:param str or Callable reduction: Reduction to be used to reduce
the aggregated result of the modules in `nets` to the desired output
dimension. Available reductions include
sum: ``+``, product: ``*``, mean: ``mean``, min: ``min``, max: ``max``.
:param bool or Callable scale: Scaling the final output before returning the
forward pass, default True.
sum: ``+``, product: ``*``, mean: ``mean``, min: ``min``,
max: ``max``.
:param bool or Callable scale: Scaling the final output before returning
the forward pass, default True.
:param bool or Callable translation: Translating the final output before
returning the forward pass, default True.
.. warning::
In the forward pass we do not check if the input is instance of
:py:obj:`pina.label_tensor.LabelTensor` or :class:`torch.Tensor`. A general rule is
that for a :py:obj:`pina.label_tensor.LabelTensor` input both list of integers and
list of strings can be passed for ``input_indeces_branch_net``
and ``input_indeces_trunk_net``. Differently, for a :class:`torch.Tensor`
only a list of integers can be passed for ``input_indeces_branch_net``
and ``input_indeces_trunk_net``.
:py:obj:`pina.label_tensor.LabelTensor` or :class:`torch.Tensor`.
A general rule is that for a :py:obj:`pina.label_tensor.LabelTensor`
input both list of integers and list of strings can be passed for
``input_indeces_branch_net`` and ``input_indeces_trunk_net``.
Differently, for a :class:`torch.Tensor` only a list of integers can
be passed for ``input_indeces_branch_net`` and
``input_indeces_trunk_net``.
:Example:
>>> branch_net = FeedForward(input_dimensons=1, output_dimensions=10)
>>> branch_net = FeedForward(input_dimensons=1,
... output_dimensions=10)
>>> trunk_net = FeedForward(input_dimensons=1, output_dimensions=10)
>>> model = DeepONet(branch_net=branch_net,
... trunk_net=trunk_net,
@@ -395,7 +413,8 @@ class DeepONet(MIONet):
"""
Defines the computation performed at every call.
:param LabelTensor or torch.Tensor x: The input tensor for the forward call.
:param LabelTensor or torch.Tensor x: The input tensor for the forward
call.
:return: The output computed by the DeepONet model.
:rtype: LabelTensor or torch.Tensor
"""

View File

@@ -1,7 +1,7 @@
"""Module for FeedForward model"""
import torch
import torch.nn as nn
from torch import nn
from ..utils import check_consistency
from .block.residual import EnhancedLinear
@@ -13,10 +13,12 @@ class FeedForward(torch.nn.Module):
:param int input_dimensions: The number of input components of the model.
Expected tensor shape of the form :math:`(*, d)`, where *
means any number of dimensions including none, and :math:`d` the ``input_dimensions``.
means any number of dimensions including none, and :math:`d` the
``input_dimensions``.
:param int output_dimensions: The number of output components of the model.
Expected tensor shape of the form :math:`(*, d)`, where *
means any number of dimensions including none, and :math:`d` the ``output_dimensions``.
means any number of dimensions including none, and :math:`d` the
``output_dimensions``.
:param int inner_size: number of neurons in the hidden layer(s). Default is
20.
:param int n_layers: number of hidden layers. Default is 2.
@@ -24,9 +26,9 @@ class FeedForward(torch.nn.Module):
:class:`torch.nn.Module` is passed, this is used as activation function
after any layers, except the last one. If a list of Modules is passed,
they are used as activation functions at any layers, in order.
:param list(int) | tuple(int) layers: a list containing the number of neurons for
any hidden layers. If specified, the parameters ``n_layers`` e
``inner_size`` are not considered.
:param list(int) | tuple(int) layers: a list containing the number of
neurons for any hidden layers. If specified, the parameters ``n_layers``
and ``inner_size`` are not considered.
:param bool bias: If ``True`` the MLP will consider some bias.
"""
@@ -72,10 +74,10 @@ class FeedForward(torch.nn.Module):
raise RuntimeError("uncosistent number of layers and functions")
unique_list = []
for layer, func in zip(self.layers[:-1], self.functions):
for layer, func_ in zip(self.layers[:-1], self.functions):
unique_list.append(layer)
if func is not None:
unique_list.append(func())
if func_ is not None:
unique_list.append(func_())
unique_list.append(self.layers[-1])
self.model = nn.Sequential(*unique_list)
@@ -95,24 +97,27 @@ class FeedForward(torch.nn.Module):
class ResidualFeedForward(torch.nn.Module):
"""
The PINA implementation of feedforward network, also with skipped connection
and transformer network, as presented in **Understanding and mitigating gradient
pathologies in physics-informed neural networks**
and transformer network, as presented in **Understanding and mitigating
gradient pathologies in physics-informed neural networks**
.. seealso::
**Original reference**: Wang, Sifan, Yujun Teng, and Paris Perdikaris.
*Understanding and mitigating gradient flow pathologies in physics-informed
neural networks*. SIAM Journal on Scientific Computing 43.5 (2021): A3055-A3081.
*Understanding and mitigating gradient flow pathologies in
physics-informed neural networks*. SIAM Journal on Scientific Computing
43.5 (2021): A3055-A3081.
DOI: `10.1137/20M1318043
<https://epubs.siam.org/doi/abs/10.1137/20M1318043>`_
:param int input_dimensions: The number of input components of the model.
Expected tensor shape of the form :math:`(*, d)`, where *
means any number of dimensions including none, and :math:`d` the ``input_dimensions``.
means any number of dimensions including none, and :math:`d` the
``input_dimensions``.
:param int output_dimensions: The number of output components of the model.
Expected tensor shape of the form :math:`(*, d)`, where *
means any number of dimensions including none, and :math:`d` the ``output_dimensions``.
means any number of dimensions including none, and :math:`d` the
``output_dimensions``.
:param int inner_size: number of neurons in the hidden layer(s). Default is
20.
:param int n_layers: number of hidden layers. Default is 2.
@@ -148,6 +153,63 @@ class ResidualFeedForward(torch.nn.Module):
check_consistency(func, torch.nn.Module, subclass=True)
check_consistency(bias, bool)
transformer_nets = self._check_transformer_nets(
transformer_nets, input_dimensions, inner_size
)
# assign variables
self.transformer_nets = nn.ModuleList(transformer_nets)
# build layers
layers = [inner_size] * n_layers
layers = layers.copy()
layers.insert(0, input_dimensions)
self.layers = []
for i in range(len(layers) - 1):
self.layers.append(nn.Linear(layers[i], layers[i + 1], bias=bias))
self.last_layer = nn.Linear(
layers[len(layers) - 1], output_dimensions, bias=bias
)
if isinstance(func, list):
self.functions = func()
else:
self.functions = [func() for _ in range(len(self.layers))]
if len(self.layers) != len(self.functions):
raise RuntimeError("uncosistent number of layers and functions")
unique_list = []
for layer, func_ in zip(self.layers, self.functions):
unique_list.append(EnhancedLinear(layer=layer, activation=func_))
self.inner_layers = torch.nn.Sequential(*unique_list)
def forward(self, x):
"""
Defines the computation performed at every call.
:param x: The tensor to apply the forward pass.
:type x: torch.Tensor
:return: the output computed by the model.
:rtype: torch.Tensor
"""
# enhance the input with transformer
input_ = []
for nets in self.transformer_nets:
input_.append(nets(x))
# skip connections pass
for layer in self.inner_layers.children():
x = layer(x)
x = (1.0 - x) * input_[0] + x * input_[1]
# last layer
return self.last_layer(x)
@staticmethod
def _check_transformer_nets(transformer_nets, input_dimensions, inner_size):
# check transformer nets
if transformer_nets is None:
transformer_nets = [
@@ -172,75 +234,25 @@ class ResidualFeedForward(torch.nn.Module):
for net in transformer_nets:
if not isinstance(net, nn.Module):
raise ValueError(
"transformer_nets needs to be a list of torch.nn.Module."
"transformer_nets needs to be a list of "
"torch.nn.Module."
)
x = torch.rand(10, input_dimensions)
try:
out = net(x)
except RuntimeError:
except RuntimeError as e:
raise ValueError(
"transformer network input incompatible with input_dimensions."
)
"transformer network input incompatible with "
"input_dimensions."
) from e
if out.shape[-1] != inner_size:
raise ValueError(
"transformer network output incompatible with inner_size."
"transformer network output incompatible with "
"inner_size."
)
else:
RuntimeError(
"Runtime error for transformer nets, check official documentation."
raise RuntimeError(
"Runtime error for transformer nets, check official "
"documentation."
)
# assign variables
self.input_dimension = input_dimensions
self.output_dimension = output_dimensions
self.transformer_nets = nn.ModuleList(transformer_nets)
# build layers
layers = [inner_size] * n_layers
tmp_layers = layers.copy()
tmp_layers.insert(0, self.input_dimension)
self.layers = []
for i in range(len(tmp_layers) - 1):
self.layers.append(
nn.Linear(tmp_layers[i], tmp_layers[i + 1], bias=bias)
)
self.last_layer = nn.Linear(
tmp_layers[len(tmp_layers) - 1], output_dimensions, bias=bias
)
if isinstance(func, list):
self.functions = func()
else:
self.functions = [func() for _ in range(len(self.layers))]
if len(self.layers) != len(self.functions):
raise RuntimeError("uncosistent number of layers and functions")
unique_list = []
for layer, func in zip(self.layers, self.functions):
unique_list.append(EnhancedLinear(layer=layer, activation=func))
self.inner_layers = torch.nn.Sequential(*unique_list)
def forward(self, x):
"""
Defines the computation performed at every call.
:param x: The tensor to apply the forward pass.
:type x: torch.Tensor
:return: the output computed by the model.
:rtype: torch.Tensor
"""
# enhance the input with transformer
input_ = []
for nets in self.transformer_nets:
input_.append(nets(x))
# skip connections pass
for layer in self.inner_layers.children():
x = layer(x)
x = (1.0 - x) * input_[0] + x * input_[1]
# last layer
return self.last_layer(x)
return transformer_nets

View File

@@ -2,10 +2,10 @@
Fourier Neural Operator Module.
"""
import torch
import torch.nn as nn
from ..label_tensor import LabelTensor
import warnings
import torch
from torch import nn
from ..label_tensor import LabelTensor
from ..utils import check_consistency
from .block.fourier_block import FourierBlock1D, FourierBlock2D, FourierBlock3D
from .kernel_neural_operator import KernelNeuralOperator
@@ -57,36 +57,22 @@ class FourierIntegralKernel(torch.nn.Module):
super().__init__()
# check type consistency
check_consistency(dimensions, int)
check_consistency(padding, int)
check_consistency(padding_type, str)
check_consistency(inner_size, int)
check_consistency(n_layers, int)
check_consistency(func, nn.Module, subclass=True)
if layers is not None:
if isinstance(layers, (tuple, list)):
check_consistency(layers, int)
else:
raise ValueError("layers must be tuple or list of int.")
if not isinstance(n_modes, (list, tuple, int)):
raise ValueError(
"n_modes must be a int or list or tuple of valid modes."
" More information on the official documentation."
)
self._check_consistency(
dimensions,
padding,
padding_type,
inner_size,
n_layers,
func,
layers,
n_modes,
)
# assign padding
self._padding = padding
# initialize fourier layer for each dimension
if dimensions == 1:
fourier_layer = FourierBlock1D
elif dimensions == 2:
fourier_layer = FourierBlock2D
elif dimensions == 3:
fourier_layer = FourierBlock3D
else:
raise NotImplementedError("FNO implemented only for 1D/2D/3D data.")
fourier_layer = self._get_fourier_block(dimensions)
# Here we build the FNO kernels by stacking Fourier Blocks
@@ -113,24 +99,24 @@ class FourierIntegralKernel(torch.nn.Module):
raise RuntimeError(
"Uncosistent number of layers and functions."
)
elif all(isinstance(i, int) for i in n_modes):
if all(isinstance(i, int) for i in n_modes):
n_modes = [n_modes] * len(layers)
else:
n_modes = [n_modes] * len(layers)
# 4. Build the FNO network
_layers = []
tmp_layers = [input_numb_fields] + layers + [output_numb_fields]
for i in range(len(layers)):
_layers.append(
self._layers = nn.Sequential(
*[
fourier_layer(
input_numb_fields=tmp_layers[i],
output_numb_fields=tmp_layers[i + 1],
n_modes=n_modes[i],
activation=_functions[i],
)
)
self._layers = nn.Sequential(*_layers)
for i in range(len(layers))
]
)
# 5. Padding values for spectral conv
if isinstance(padding, int):
@@ -158,14 +144,14 @@ class FourierIntegralKernel(torch.nn.Module):
:return: The output tensor obtained from the kernels convolution.
:rtype: torch.Tensor
"""
if isinstance(x, LabelTensor): # TODO remove when Network is fixed
if isinstance(x, LabelTensor):
warnings.warn(
"LabelTensor passed as input is not allowed,"
" casting LabelTensor to Torch.Tensor"
)
x = x.as_subclass(torch.Tensor)
# permuting the input [batch, channels, x, y, ...]
permutation_idx = [0, x.ndim - 1, *[i for i in range(1, x.ndim - 1)]]
permutation_idx = [0, x.ndim - 1, *list(range(1, x.ndim - 1))]
x = x.permute(permutation_idx)
# padding the input
@@ -179,11 +165,50 @@ class FourierIntegralKernel(torch.nn.Module):
x = x[idxs]
# permuting back [batch, x, y, ..., channels]
permutation_idx = [0, *[i for i in range(2, x.ndim)], 1]
permutation_idx = [0, *list(range(2, x.ndim)), 1]
x = x.permute(permutation_idx)
return x
@staticmethod
def _check_consistency(
dimensions,
padding,
padding_type,
inner_size,
n_layers,
func,
layers,
n_modes,
):
check_consistency(dimensions, int)
check_consistency(padding, int)
check_consistency(padding_type, str)
check_consistency(inner_size, int)
check_consistency(n_layers, int)
check_consistency(func, nn.Module, subclass=True)
if layers is not None:
if isinstance(layers, (tuple, list)):
check_consistency(layers, int)
else:
raise ValueError("layers must be tuple or list of int.")
if not isinstance(n_modes, (list, tuple, int)):
raise ValueError(
"n_modes must be a int or list or tuple of valid modes."
" More information on the official documentation."
)
@staticmethod
def _get_fourier_block(dimensions):
if dimensions == 1:
return FourierBlock1D
if dimensions == 2:
return FourierBlock2D
if dimensions == 3:
return FourierBlock3D
raise NotImplementedError("FNO implemented only for 1D/2D/3D data.")
class FNO(KernelNeuralOperator):
"""

View File

@@ -1,6 +1,10 @@
"""
Module for the Graph Neural Operator and Graph Neural Kernel.
"""
import torch
from torch.nn import Tanh
from .block import GNOBlock
from .block.gno_block import GNOBlock
from .kernel_neural_operator import KernelNeuralOperator
@@ -30,14 +34,20 @@ class GraphNeuralKernel(torch.nn.Module):
:type edge_features: int
:param n_layers: The number of kernel layers.
:type n_layers: int
:param internal_n_layers: The number of layers the FF Neural Network internal to each Kernel Layer.
:param internal_n_layers: The number of layers the FF Neural Network
internal to each Kernel Layer.
:type internal_n_layers: int
:param internal_layers: Number of neurons of hidden layers(s) in the FF Neural Network inside for each Kernel Layer.
:param internal_layers: Number of neurons of hidden layers(s) in the
FF Neural Network inside for each Kernel Layer.
:type internal_layers: list | tuple
:param internal_func: The activation function used inside the computation of the representation of the edge features in the Graph Integral Layer.
:param external_func: The activation function applied to the output of the Graph Integral Layer.
:param internal_func: The activation function used inside the
computation of the representation of the edge features in the
Graph Integral Layer.
:param external_func: The activation function applied to the output of
the Graph Integral Layer.
:type external_func: torch.nn.Module
:param shared_weights: If ``True`` the weights of the Graph Integral Layers are shared.
:param shared_weights: If ``True`` the weights of the Graph Integral
Layers are shared.
"""
super().__init__()
if external_func is None:
@@ -56,7 +66,7 @@ class GraphNeuralKernel(torch.nn.Module):
external_func=external_func,
)
self.n_layers = n_layers
self.forward = self.forward_shared
self._forward_func = self._forward_shared
else:
self.layers = torch.nn.ModuleList(
[
@@ -72,25 +82,21 @@ class GraphNeuralKernel(torch.nn.Module):
for _ in range(n_layers)
]
)
self._forward_func = self._forward_unshared
def forward(self, x, edge_index, edge_attr):
"""
The forward pass of the Graph Neural Kernel used when the weights are not shared.
:param x: The input batch.
:type x: torch.Tensor
:param edge_index: The edge index.
:type edge_index: torch.Tensor
:param edge_attr: The edge attributes.
:type edge_attr: torch.Tensor
"""
def _forward_unshared(self, x, edge_index, edge_attr):
for layer in self.layers:
x = layer(x, edge_index, edge_attr)
return x
def forward_shared(self, x, edge_index, edge_attr):
def _forward_shared(self, x, edge_index, edge_attr):
for _ in range(self.n_layers):
x = self.layers(x, edge_index, edge_attr)
return x
def forward(self, x, edge_index, edge_attr):
"""
The forward pass of the Graph Neural Kernel used when the weights are shared.
The forward pass of the Graph Neural Kernel.
:param x: The input batch.
:type x: torch.Tensor
@@ -99,9 +105,7 @@ class GraphNeuralKernel(torch.nn.Module):
:param edge_attr: The edge attributes.
:type edge_attr: torch.Tensor
"""
for _ in range(self.n_layers):
x = self.layers(x, edge_index, edge_attr)
return x
return self._forward_func(x, edge_index, edge_attr)
class GraphNeuralOperator(KernelNeuralOperator):
@@ -125,23 +129,31 @@ class GraphNeuralOperator(KernelNeuralOperator):
"""
The Graph Neural Operator constructor.
:param lifting_operator: The lifting operator mapping the node features to its hidden dimension.
:param lifting_operator: The lifting operator mapping the node features
to its hidden dimension.
:type lifting_operator: torch.nn.Module
:param projection_operator: The projection operator mapping the hidden representation of the nodes features to the output function.
:param projection_operator: The projection operator mapping the hidden
representation of the nodes features to the output function.
:type projection_operator: torch.nn.Module
:param edge_features: Number of edge features.
:type edge_features: int
:param n_layers: The number of kernel layers.
:type n_layers: int
:param internal_n_layers: The number of layers the Feed Forward Neural Network internal to each Kernel Layer.
:param internal_n_layers: The number of layers the Feed Forward Neural
Network internal to each Kernel Layer.
:type internal_n_layers: int
:param internal_layers: Number of neurons of hidden layers(s) in the FF Neural Network inside for each Kernel Layer.
:param internal_layers: Number of neurons of hidden layers(s) in the
FF Neural Network inside for each Kernel Layer.
:type internal_layers: list | tuple
:param internal_func: The activation function used inside the computation of the representation of the edge features in the Graph Integral Layer.
:param internal_func: The activation function used inside the
computation of the representation of the edge features in the
Graph Integral Layer.
:type internal_func: torch.nn.Module
:param external_func: The activation function applied to the output of the Graph Integral Kernel.
:param external_func: The activation function applied to the output of
the Graph Integral Kernel.
:type external_func: torch.nn.Module
:param shared_weights: If ``True`` the weights of the Graph Integral Layers are shared.
:param shared_weights: If ``True`` the weights of the Graph Integral
Layers are shared.
:type shared_weights: bool
"""

View File

@@ -1,3 +1,7 @@
"""
Old layers module, deprecated in 0.2.0.
"""
import warnings
from ..block import *
@@ -8,7 +12,7 @@ from ...utils import custom_warning_format
warnings.formatwarning = custom_warning_format
warnings.filterwarnings("always", category=DeprecationWarning)
warnings.warn(
f"'pina.model.layers' is deprecated and will be removed "
f"in future versions. Please use 'pina.model.block' instead.",
"'pina.model.layers' is deprecated and will be removed "
"in future versions. Please use 'pina.model.block' instead.",
DeprecationWarning,
)

View File

@@ -1,7 +1,7 @@
"""Module LowRank Neural Operator."""
import torch
from torch import nn, cat
from torch import nn
from ..utils import check_consistency
@@ -145,4 +145,4 @@ class LowRankNeuralOperator(KernelNeuralOperator):
for module in self._integral_kernels:
x = module(x, coords)
# projecting
return self._projection_operator(cat((x, coords), dim=-1))
return self._projection_operator(torch.cat((x, coords), dim=-1))

View File

@@ -1,11 +1,11 @@
"""Module for Multi FeedForward model"""
from abc import ABC, abstractmethod
import torch
from .feed_forward import FeedForward
class MultiFeedForward(torch.nn.Module):
class MultiFeedForward(torch.nn.Module, ABC):
"""
The PINA implementation of MultiFeedForward network.
@@ -24,3 +24,9 @@ class MultiFeedForward(torch.nn.Module):
for name, constructor_args in ffn_dict.items():
setattr(self, name, FeedForward(**constructor_args))
@abstractmethod
def forward(self, *args, **kwargs):
"""
TODO: Docstring
"""

View File

@@ -5,6 +5,7 @@ from ..utils import check_consistency
class Spline(torch.nn.Module):
"""TODO: Docstring for Spline."""
def __init__(self, order=4, knots=None, control_points=None) -> None:
"""
@@ -99,6 +100,7 @@ class Spline(torch.nn.Module):
@property
def control_points(self):
"""TODO: Docstring for control_points."""
return self._control_points
@control_points.setter
@@ -116,6 +118,7 @@ class Spline(torch.nn.Module):
@property
def knots(self):
"""TODO: Docstring for knots."""
return self._knots
@knots.setter