🎨 Format Python code with psf/black

This commit is contained in:
ndem0
2024-06-11 11:31:22 +00:00
committed by Nicola Demo
parent 6c3adfb03d
commit 7634d22ca7
2 changed files with 19 additions and 17 deletions

View File

@@ -158,18 +158,14 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
return self._period return self._period
class FourierFeatureEmbedding(torch.nn.Module): class FourierFeatureEmbedding(torch.nn.Module):
def __init__(self, def __init__(self, input_dimension, output_dimension, sigma):
input_dimension,
output_dimension,
sigma):
r""" r"""
Fourier Feature Embedding class for encoding input features Fourier Feature Embedding class for encoding input features
using random Fourier features.This class applies a Fourier using random Fourier features.This class applies a Fourier
transformation to the input features, transformation to the input features,
which can help in learning high-frequency variations in data. which can help in learning high-frequency variations in data.
If multiple sigma are provided, the class If multiple sigma are provided, the class
supports multiscale feature embedding, creating embeddings for supports multiscale feature embedding, creating embeddings for
each scale specified by the sigma. each scale specified by the sigma.
@@ -224,18 +220,22 @@ class FourierFeatureEmbedding(torch.nn.Module):
check_consistency(output_dimension, int) check_consistency(output_dimension, int)
check_consistency(input_dimension, int) check_consistency(input_dimension, int)
if output_dimension % 2: if output_dimension % 2:
raise RuntimeError('Expected output_dimension to be a even number, ' raise RuntimeError(
f'got {output_dimension}.') "Expected output_dimension to be a even number, "
f"got {output_dimension}."
)
# assign sigma # assign sigma
self._sigma = sigma self._sigma = sigma
# create non-trainable matrices # create non-trainable matrices
self._matrix = torch.rand( self._matrix = (
size = (input_dimension, torch.rand(
output_dimension // 2), size=(input_dimension, output_dimension // 2),
requires_grad = False requires_grad=False,
) * self.sigma )
* self.sigma
)
def forward(self, x): def forward(self, x):
""" """
@@ -248,12 +248,14 @@ class FourierFeatureEmbedding(torch.nn.Module):
# compute random matrix multiplication # compute random matrix multiplication
out = torch.mm(x, self._matrix) out = torch.mm(x, self._matrix)
# return embedding # return embedding
return torch.cat([torch.cos(2*torch.pi*out), torch.sin(2*torch.pi*out)], dim=-1) return torch.cat(
[torch.cos(2 * torch.pi * out), torch.sin(2 * torch.pi * out)],
dim=-1,
)
@property @property
def sigma(self): def sigma(self):
""" """
Returning the variance of the sampled matrix for Fourier Embedding. Returning the variance of the sampled matrix for Fourier Embedding.
""" """
return self._sigma return self._sigma

View File

@@ -246,7 +246,7 @@ class Plotter:
plt.tight_layout() plt.tight_layout()
if title is not None: if title is not None:
plt.title(title) plt.title(title)
if filename: if filename:
plt.savefig(filename) plt.savefig(filename)
plt.close() plt.close()