🎨 Format Python code with psf/black

This commit is contained in:
ndem0
2024-06-11 11:31:22 +00:00
committed by Nicola Demo
parent 6c3adfb03d
commit 7634d22ca7
2 changed files with 19 additions and 17 deletions

View File

@@ -158,12 +158,8 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
return self._period return self._period
class FourierFeatureEmbedding(torch.nn.Module): class FourierFeatureEmbedding(torch.nn.Module):
def __init__(self, def __init__(self, input_dimension, output_dimension, sigma):
input_dimension,
output_dimension,
sigma):
r""" r"""
Fourier Feature Embedding class for encoding input features Fourier Feature Embedding class for encoding input features
using random Fourier features.This class applies a Fourier using random Fourier features.This class applies a Fourier
@@ -224,18 +220,22 @@ class FourierFeatureEmbedding(torch.nn.Module):
check_consistency(output_dimension, int) check_consistency(output_dimension, int)
check_consistency(input_dimension, int) check_consistency(input_dimension, int)
if output_dimension % 2: if output_dimension % 2:
raise RuntimeError('Expected output_dimension to be a even number, ' raise RuntimeError(
f'got {output_dimension}.') "Expected output_dimension to be a even number, "
f"got {output_dimension}."
)
# assign sigma # assign sigma
self._sigma = sigma self._sigma = sigma
# create non-trainable matrices # create non-trainable matrices
self._matrix = torch.rand( self._matrix = (
size = (input_dimension, torch.rand(
output_dimension // 2), size=(input_dimension, output_dimension // 2),
requires_grad = False requires_grad=False,
) * self.sigma )
* self.sigma
)
def forward(self, x): def forward(self, x):
""" """
@@ -248,8 +248,10 @@ class FourierFeatureEmbedding(torch.nn.Module):
# compute random matrix multiplication # compute random matrix multiplication
out = torch.mm(x, self._matrix) out = torch.mm(x, self._matrix)
# return embedding # return embedding
return torch.cat([torch.cos(2*torch.pi*out), torch.sin(2*torch.pi*out)], dim=-1) return torch.cat(
[torch.cos(2 * torch.pi * out), torch.sin(2 * torch.pi * out)],
dim=-1,
)
@property @property
def sigma(self): def sigma(self):