🎨 Format Python code with psf/black
This commit is contained in:
@@ -158,12 +158,8 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
|
|||||||
return self._period
|
return self._period
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class FourierFeatureEmbedding(torch.nn.Module):
|
class FourierFeatureEmbedding(torch.nn.Module):
|
||||||
def __init__(self,
|
def __init__(self, input_dimension, output_dimension, sigma):
|
||||||
input_dimension,
|
|
||||||
output_dimension,
|
|
||||||
sigma):
|
|
||||||
r"""
|
r"""
|
||||||
Fourier Feature Embedding class for encoding input features
|
Fourier Feature Embedding class for encoding input features
|
||||||
using random Fourier features.This class applies a Fourier
|
using random Fourier features.This class applies a Fourier
|
||||||
@@ -224,18 +220,22 @@ class FourierFeatureEmbedding(torch.nn.Module):
|
|||||||
check_consistency(output_dimension, int)
|
check_consistency(output_dimension, int)
|
||||||
check_consistency(input_dimension, int)
|
check_consistency(input_dimension, int)
|
||||||
if output_dimension % 2:
|
if output_dimension % 2:
|
||||||
raise RuntimeError('Expected output_dimension to be a even number, '
|
raise RuntimeError(
|
||||||
f'got {output_dimension}.')
|
"Expected output_dimension to be a even number, "
|
||||||
|
f"got {output_dimension}."
|
||||||
|
)
|
||||||
|
|
||||||
# assign sigma
|
# assign sigma
|
||||||
self._sigma = sigma
|
self._sigma = sigma
|
||||||
|
|
||||||
# create non-trainable matrices
|
# create non-trainable matrices
|
||||||
self._matrix = torch.rand(
|
self._matrix = (
|
||||||
size = (input_dimension,
|
torch.rand(
|
||||||
output_dimension // 2),
|
size=(input_dimension, output_dimension // 2),
|
||||||
requires_grad = False
|
requires_grad=False,
|
||||||
) * self.sigma
|
)
|
||||||
|
* self.sigma
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
"""
|
"""
|
||||||
@@ -248,8 +248,10 @@ class FourierFeatureEmbedding(torch.nn.Module):
|
|||||||
# compute random matrix multiplication
|
# compute random matrix multiplication
|
||||||
out = torch.mm(x, self._matrix)
|
out = torch.mm(x, self._matrix)
|
||||||
# return embedding
|
# return embedding
|
||||||
return torch.cat([torch.cos(2*torch.pi*out), torch.sin(2*torch.pi*out)], dim=-1)
|
return torch.cat(
|
||||||
|
[torch.cos(2 * torch.pi * out), torch.sin(2 * torch.pi * out)],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sigma(self):
|
def sigma(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user