🎨 Format Python code with psf/black

This commit is contained in:
ndem0
2024-06-11 11:31:22 +00:00
committed by Nicola Demo
parent 6c3adfb03d
commit 7634d22ca7
2 changed files with 19 additions and 17 deletions

View File

@@ -158,12 +158,8 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
return self._period
class FourierFeatureEmbedding(torch.nn.Module):
def __init__(self,
input_dimension,
output_dimension,
sigma):
def __init__(self, input_dimension, output_dimension, sigma):
r"""
Fourier Feature Embedding class for encoding input features
using random Fourier features.This class applies a Fourier
@@ -224,18 +220,22 @@ class FourierFeatureEmbedding(torch.nn.Module):
check_consistency(output_dimension, int)
check_consistency(input_dimension, int)
if output_dimension % 2:
raise RuntimeError('Expected output_dimension to be a even number, '
f'got {output_dimension}.')
raise RuntimeError(
"Expected output_dimension to be a even number, "
f"got {output_dimension}."
)
# assign sigma
self._sigma = sigma
# create non-trainable matrices
self._matrix = torch.rand(
size = (input_dimension,
output_dimension // 2),
requires_grad = False
) * self.sigma
self._matrix = (
torch.rand(
size=(input_dimension, output_dimension // 2),
requires_grad=False,
)
* self.sigma
)
def forward(self, x):
"""
@@ -248,8 +248,10 @@ class FourierFeatureEmbedding(torch.nn.Module):
# compute random matrix multiplication
out = torch.mm(x, self._matrix)
# return embedding
return torch.cat([torch.cos(2*torch.pi*out), torch.sin(2*torch.pi*out)], dim=-1)
return torch.cat(
[torch.cos(2 * torch.pi * out), torch.sin(2 * torch.pi * out)],
dim=-1,
)
@property
def sigma(self):