🎨 Format Python code with psf/black
This commit is contained in:
@@ -158,18 +158,14 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
|
||||
return self._period
|
||||
|
||||
|
||||
|
||||
class FourierFeatureEmbedding(torch.nn.Module):
|
||||
def __init__(self,
|
||||
input_dimension,
|
||||
output_dimension,
|
||||
sigma):
|
||||
def __init__(self, input_dimension, output_dimension, sigma):
|
||||
r"""
|
||||
Fourier Feature Embedding class for encoding input features
|
||||
using random Fourier features.This class applies a Fourier
|
||||
transformation to the input features,
|
||||
which can help in learning high-frequency variations in data.
|
||||
If multiple sigma are provided, the class
|
||||
If multiple sigma are provided, the class
|
||||
supports multiscale feature embedding, creating embeddings for
|
||||
each scale specified by the sigma.
|
||||
|
||||
@@ -224,18 +220,22 @@ class FourierFeatureEmbedding(torch.nn.Module):
|
||||
check_consistency(output_dimension, int)
|
||||
check_consistency(input_dimension, int)
|
||||
if output_dimension % 2:
|
||||
raise RuntimeError('Expected output_dimension to be a even number, '
|
||||
f'got {output_dimension}.')
|
||||
raise RuntimeError(
|
||||
"Expected output_dimension to be a even number, "
|
||||
f"got {output_dimension}."
|
||||
)
|
||||
|
||||
# assign sigma
|
||||
self._sigma = sigma
|
||||
|
||||
# create non-trainable matrices
|
||||
self._matrix = torch.rand(
|
||||
size = (input_dimension,
|
||||
output_dimension // 2),
|
||||
requires_grad = False
|
||||
) * self.sigma
|
||||
self._matrix = (
|
||||
torch.rand(
|
||||
size=(input_dimension, output_dimension // 2),
|
||||
requires_grad=False,
|
||||
)
|
||||
* self.sigma
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
@@ -248,12 +248,14 @@ class FourierFeatureEmbedding(torch.nn.Module):
|
||||
# compute random matrix multiplication
|
||||
out = torch.mm(x, self._matrix)
|
||||
# return embedding
|
||||
return torch.cat([torch.cos(2*torch.pi*out), torch.sin(2*torch.pi*out)], dim=-1)
|
||||
|
||||
return torch.cat(
|
||||
[torch.cos(2 * torch.pi * out), torch.sin(2 * torch.pi * out)],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
@property
|
||||
def sigma(self):
|
||||
"""
|
||||
Returning the variance of the sampled matrix for Fourier Embedding.
|
||||
"""
|
||||
return self._sigma
|
||||
return self._sigma
|
||||
|
||||
@@ -246,7 +246,7 @@ class Plotter:
|
||||
plt.tight_layout()
|
||||
if title is not None:
|
||||
plt.title(title)
|
||||
|
||||
|
||||
if filename:
|
||||
plt.savefig(filename)
|
||||
plt.close()
|
||||
|
||||
Reference in New Issue
Block a user