🎨 Format Python code with psf/black

This commit is contained in:
ndem0
2024-06-11 11:31:22 +00:00
committed by Nicola Demo
parent 6c3adfb03d
commit 7634d22ca7
2 changed files with 19 additions and 17 deletions

View File

@@ -158,18 +158,14 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
return self._period
class FourierFeatureEmbedding(torch.nn.Module):
def __init__(self,
input_dimension,
output_dimension,
sigma):
def __init__(self, input_dimension, output_dimension, sigma):
r"""
Fourier Feature Embedding class for encoding input features
using random Fourier features.This class applies a Fourier
transformation to the input features,
which can help in learning high-frequency variations in data.
If multiple sigma are provided, the class
If multiple sigma are provided, the class
supports multiscale feature embedding, creating embeddings for
each scale specified by the sigma.
@@ -224,18 +220,22 @@ class FourierFeatureEmbedding(torch.nn.Module):
check_consistency(output_dimension, int)
check_consistency(input_dimension, int)
if output_dimension % 2:
raise RuntimeError('Expected output_dimension to be a even number, '
f'got {output_dimension}.')
raise RuntimeError(
"Expected output_dimension to be a even number, "
f"got {output_dimension}."
)
# assign sigma
self._sigma = sigma
# create non-trainable matrices
self._matrix = torch.rand(
size = (input_dimension,
output_dimension // 2),
requires_grad = False
) * self.sigma
self._matrix = (
torch.rand(
size=(input_dimension, output_dimension // 2),
requires_grad=False,
)
* self.sigma
)
def forward(self, x):
"""
@@ -248,12 +248,14 @@ class FourierFeatureEmbedding(torch.nn.Module):
# compute random matrix multiplication
out = torch.mm(x, self._matrix)
# return embedding
return torch.cat([torch.cos(2*torch.pi*out), torch.sin(2*torch.pi*out)], dim=-1)
return torch.cat(
[torch.cos(2 * torch.pi * out), torch.sin(2 * torch.pi * out)],
dim=-1,
)
@property
def sigma(self):
"""
Returning the variance of the sampled matrix for Fourier Embedding.
"""
return self._sigma
return self._sigma

View File

@@ -246,7 +246,7 @@ class Plotter:
plt.tight_layout()
if title is not None:
plt.title(title)
if filename:
plt.savefig(filename)
plt.close()