* Fourier feature embedding added, and modify typos in doc Periodic Boundary Embedding.
* Fixing doc for Periodic Boundary Embedding. * Creating doc for Fourier Feature Embedding.
This commit is contained in:
committed by
Nicola Demo
parent
89ee010a94
commit
5785b2732c
@@ -1,7 +1,8 @@
|
||||
""" Periodic Boundary Embedding modulus. """
|
||||
""" Embedding modulus. """
|
||||
|
||||
import torch
|
||||
from pina.utils import check_consistency
|
||||
from typing import Union, Sequence
|
||||
|
||||
|
||||
class PeriodicBoundaryEmbedding(torch.nn.Module):
|
||||
@@ -100,7 +101,7 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
|
||||
Forward pass to compute the periodic boundary conditions embedding.
|
||||
|
||||
:param torch.Tensor x: Input tensor.
|
||||
:return: Fourier embeddings of the input.
|
||||
:return: Periodic embedding of the input.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
omega = torch.stack(
|
||||
@@ -155,3 +156,112 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
|
||||
The period of the periodic function to approximate.
|
||||
"""
|
||||
return self._period
|
||||
|
||||
|
||||
|
||||
class FourierFeatureEmbedding(torch.nn.Module):
|
||||
def __init__(self,
|
||||
input_dimension : int,
|
||||
output_dimension : int,
|
||||
sigmas : Union[float, int, Sequence[float], Sequence[int]],
|
||||
embedding_output_dimension : int = None):
|
||||
r"""
|
||||
Fourier Feature Embedding class for encoding input features
|
||||
using random Fourier features.This class applies a Fourier
|
||||
transformation to the input features,
|
||||
which can help in learning high-frequency variations in data.
|
||||
If multiple sigmas are provided, the class
|
||||
supports multiscale feature embedding, creating embeddings for
|
||||
each scale specified by the sigmas.
|
||||
|
||||
The :obj:`FourierFeatureEmbedding` augments the input
|
||||
by the following formula (3.10 of original paper):
|
||||
|
||||
.. math::
|
||||
\mathbf{x} \rightarrow \tilde{\mathbf{x}} = \left[
|
||||
\cos\left( \mathbf{B} \mathbf{x} \right),
|
||||
\sin\left( \mathbf{B} \mathbf{x} \right)\right],
|
||||
|
||||
where :math:`\mathbf{B}_{ij} \sim \mathcal{N}(0, \sigma^2)`.
|
||||
|
||||
In case multiple ``sigmas`` are passed, the resulting embeddings
|
||||
are concateneted:
|
||||
|
||||
.. math::
|
||||
\mathbf{x} \rightarrow \tilde{\mathbf{x}} = \left[
|
||||
\cos\left( \mathbf{B}^1 \mathbf{x} \right),
|
||||
\sin\left( \mathbf{B}^1 \mathbf{x} \right),
|
||||
\cos\left( \mathbf{B}^2 \mathbf{x} \right),
|
||||
\sin\left( \mathbf{B}^3 \mathbf{x} \right),
|
||||
\dots,
|
||||
\cos\left( \mathbf{B}^M \mathbf{x} \right),
|
||||
\sin\left( \mathbf{B}^M \mathbf{x} \right)\right],
|
||||
|
||||
where :math:`\mathbf{B}^k_{ij} \sim \mathcal{N}(0, \sigma_k^2) \quad
|
||||
k \in (1, \dots, M)`.
|
||||
|
||||
.. seealso::
|
||||
**Original reference**:
|
||||
Wang, Sifan, Hanwen Wang, and Paris Perdikaris. *On the eigenvector
|
||||
bias of Fourier feature networks: From regression to solving
|
||||
multi-scale PDEs with physics-informed neural networks.*
|
||||
Computer Methods in Applied Mechanics and
|
||||
Engineering 384 (2021): 113938.
|
||||
DOI: `10.1016/j.cma.2021.113938.
|
||||
<https://doi.org/10.1016/j.cma.2021.113938>`_
|
||||
|
||||
:param int input_dimension: The input vector dimension of the layer.
|
||||
:param int output_dimension: The output dimension of the layer.
|
||||
:param sigmas: The standard deviation(s) used for the Fourier embedding.
|
||||
This can be a single float or integer, or a sequence of floats
|
||||
or integers. If a sequence is provided, the embedding will be
|
||||
computed for each sigma separately, enabling multiscale embeddings.
|
||||
:type sigmas: Union[float, int, Sequence[float], Sequence[int]]
|
||||
:param int output_dimension: The emebedding output dimension of the
|
||||
random matrix use to compute the fourier feature. If ``None``, it
|
||||
will be the same as ``output_dimension``, default ``None``.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# check consistency
|
||||
check_consistency(sigmas, (int, float))
|
||||
if isinstance(sigmas, (int, float)):
|
||||
sigmas = [sigmas]
|
||||
check_consistency(output_dimension, int)
|
||||
check_consistency(input_dimension, int)
|
||||
|
||||
if embedding_output_dimension is None:
|
||||
embedding_output_dimension = output_dimension
|
||||
check_consistency(embedding_output_dimension, int)
|
||||
|
||||
# assign
|
||||
self.sigmas = sigmas
|
||||
|
||||
# create non-trainable matrices
|
||||
self._matrices = [
|
||||
torch.rand(
|
||||
size = (input_dimension,
|
||||
embedding_output_dimension),
|
||||
requires_grad = False) * sigma for sigma in sigmas
|
||||
]
|
||||
|
||||
# create linear layer to map to the output dimension
|
||||
self._linear = torch.nn.Linear(
|
||||
in_features=2*len(sigmas)*embedding_output_dimension,
|
||||
out_features=output_dimension)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass to compute the fourier embedding.
|
||||
|
||||
:param torch.Tensor x: Input tensor.
|
||||
:return: Fourier embeddings of the input.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
# compute random matrix multiplication
|
||||
out = torch.cat([torch.mm(x, m) for m in self._matrices], dim=-1)
|
||||
# compute cos/sin emebedding
|
||||
out = torch.cat([torch.cos(out), torch.sin(out)], dim=-1)
|
||||
# return linear layer mapping
|
||||
return self._linear(out)
|
||||
Reference in New Issue
Block a user