diff --git a/pina/model/layers/embedding.py b/pina/model/layers/embedding.py index 2729665..c9f41da 100644 --- a/pina/model/layers/embedding.py +++ b/pina/model/layers/embedding.py @@ -158,18 +158,14 @@ class PeriodicBoundaryEmbedding(torch.nn.Module): return self._period - class FourierFeatureEmbedding(torch.nn.Module): - def __init__(self, - input_dimension, - output_dimension, - sigma): + def __init__(self, input_dimension, output_dimension, sigma): r""" Fourier Feature Embedding class for encoding input features using random Fourier features.This class applies a Fourier transformation to the input features, which can help in learning high-frequency variations in data. - If multiple sigma are provided, the class + If multiple sigma are provided, the class supports multiscale feature embedding, creating embeddings for each scale specified by the sigma. @@ -224,18 +220,22 @@ class FourierFeatureEmbedding(torch.nn.Module): check_consistency(output_dimension, int) check_consistency(input_dimension, int) if output_dimension % 2: - raise RuntimeError('Expected output_dimension to be a even number, ' - f'got {output_dimension}.') + raise RuntimeError( + "Expected output_dimension to be a even number, " + f"got {output_dimension}." + ) # assign sigma self._sigma = sigma # create non-trainable matrices - self._matrix = torch.rand( - size = (input_dimension, - output_dimension // 2), - requires_grad = False - ) * self.sigma + self._matrix = ( + torch.rand( + size=(input_dimension, output_dimension // 2), + requires_grad=False, + ) + * self.sigma + ) def forward(self, x): """ @@ -248,12 +248,14 @@ class FourierFeatureEmbedding(torch.nn.Module): # compute random matrix multiplication out = torch.mm(x, self._matrix) # return embedding - return torch.cat([torch.cos(2*torch.pi*out), torch.sin(2*torch.pi*out)], dim=-1) - + return torch.cat( + [torch.cos(2 * torch.pi * out), torch.sin(2 * torch.pi * out)], + dim=-1, + ) @property def sigma(self): """ Returning the variance of the sampled matrix for Fourier Embedding. """ - return self._sigma \ No newline at end of file + return self._sigma diff --git a/pina/plotter.py b/pina/plotter.py index 5d9a94a..041ef05 100644 --- a/pina/plotter.py +++ b/pina/plotter.py @@ -246,7 +246,7 @@ class Plotter: plt.tight_layout() if title is not None: plt.title(title) - + if filename: plt.savefig(filename) plt.close()