fix embedding by removing linear layer

This commit is contained in:
Monthly Tag bot
2024-06-06 15:00:30 +02:00
committed by Nicola Demo
parent 5785b2732c
commit 4b64998f45
2 changed files with 45 additions and 61 deletions

View File

@@ -161,18 +161,17 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
class FourierFeatureEmbedding(torch.nn.Module):
def __init__(self,
input_dimension : int,
output_dimension : int,
sigmas : Union[float, int, Sequence[float], Sequence[int]],
embedding_output_dimension : int = None):
input_dimension,
output_dimension,
sigma):
r"""
Fourier Feature Embedding class for encoding input features
using random Fourier features.This class applies a Fourier
transformation to the input features,
which can help in learning high-frequency variations in data.
If multiple sigmas are provided, the class
If multiple sigma are provided, the class
supports multiscale feature embedding, creating embeddings for
each scale specified by the sigmas.
each scale specified by the sigma.
The :obj:`FourierFeatureEmbedding` augments the input
by the following formula (3.10 of original paper):
@@ -184,7 +183,7 @@ class FourierFeatureEmbedding(torch.nn.Module):
where :math:`\mathbf{B}_{ij} \sim \mathcal{N}(0, \sigma^2)`.
In case multiple ``sigmas`` are passed, the resulting embeddings
In case multiple ``sigma`` are passed, the resulting embeddings
are concateneted:
.. math::
@@ -211,45 +210,32 @@ class FourierFeatureEmbedding(torch.nn.Module):
<https://doi.org/10.1016/j.cma.2021.113938>`_
:param int input_dimension: The input vector dimension of the layer.
:param int output_dimension: The output dimension of the layer.
:param sigmas: The standard deviation(s) used for the Fourier embedding.
This can be a single float or integer, or a sequence of floats
or integers. If a sequence is provided, the embedding will be
computed for each sigma separately, enabling multiscale embeddings.
:type sigmas: Union[float, int, Sequence[float], Sequence[int]]
:param int output_dimension: The emebedding output dimension of the
random matrix use to compute the fourier feature. If ``None``, it
will be the same as ``output_dimension``, default ``None``.
:param int output_dimension: The output dimension of the layer. The
output is obtained as a concatenation of the cosine and sine
embedding, hence it must be a multiple of two (even number).
:param int | float sigma: The standard deviation used for the
Fourier Embedding. This value must reflect the granularity of the
scale in the differential equation solution.
"""
super().__init__()
# check consistency
check_consistency(sigmas, (int, float))
if isinstance(sigmas, (int, float)):
sigmas = [sigmas]
check_consistency(sigma, (int, float))
check_consistency(output_dimension, int)
check_consistency(input_dimension, int)
if output_dimension % 2:
raise RuntimeError('Expected output_dimension to be a even number, '
f'got {output_dimension}.')
if embedding_output_dimension is None:
embedding_output_dimension = output_dimension
check_consistency(embedding_output_dimension, int)
# assign
self.sigmas = sigmas
# assign sigma
self._sigma = sigma
# create non-trainable matrices
self._matrices = [
torch.rand(
size = (input_dimension,
embedding_output_dimension),
requires_grad = False) * sigma for sigma in sigmas
]
# create linear layer to map to the output dimension
self._linear = torch.nn.Linear(
in_features=2*len(sigmas)*embedding_output_dimension,
out_features=output_dimension)
self._matrix = torch.rand(
size = (input_dimension,
output_dimension // 2),
requires_grad = False
)
def forward(self, x):
"""
@@ -260,8 +246,14 @@ class FourierFeatureEmbedding(torch.nn.Module):
:rtype: torch.Tensor
"""
# compute random matrix multiplication
out = torch.cat([torch.mm(x, m) for m in self._matrices], dim=-1)
# compute cos/sin emebedding
out = torch.cat([torch.cos(out), torch.sin(out)], dim=-1)
# return linear layer mapping
return self._linear(out)
out = torch.mm(x, self._matrix)
# return embedding
return torch.cat([torch.cos(out), torch.sin(out)], dim=-1)
@property
def sigma(self):
"""
Returning the variance of the sampled matrix for Fourier Embedding.
"""
return self._sigma