diff --git a/pina/model/layers/__init__.py b/pina/model/layers/__init__.py index 24cad70..77ee587 100644 --- a/pina/model/layers/__init__.py +++ b/pina/model/layers/__init__.py @@ -9,7 +9,7 @@ __all__ = [ "FourierBlock2D", "FourierBlock3D", "PODBlock", - "PeriodicBoundaryEmbedding" + "PeriodicBoundaryEmbedding", ] from .convolution_2d import ContinuousConvBlock diff --git a/pina/model/layers/embedding.py b/pina/model/layers/embedding.py index fd90a27..8e623df 100644 --- a/pina/model/layers/embedding.py +++ b/pina/model/layers/embedding.py @@ -7,8 +7,8 @@ from pina.utils import check_consistency class PeriodicBoundaryEmbedding(torch.nn.Module): r""" Imposing hard constraint periodic boundary conditions by embedding the - input. - + input. + A periodic function :math:`u:\mathbb{R}^{\rm{in}} \rightarrow\mathbb{R}^{\rm{out}}` periodic in the spatial coordinates :math:`\mathbf{x}` with periods :math:`\mathbf{L}` is such that: @@ -30,7 +30,7 @@ class PeriodicBoundaryEmbedding(torch.nn.Module): where :math:`\text{dim}(\tilde{\mathbf{x}}) = 3\text{dim}(\mathbf{x})`. .. seealso:: - **Original reference**: + **Original reference**: 1. Dong, Suchuan, and Naxian Ni (2021). *A method for representing periodic functions and enforcing exactly periodic boundary conditions with deep neural networks*. Journal of Computational @@ -52,6 +52,7 @@ class PeriodicBoundaryEmbedding(torch.nn.Module): :math:`>3`. The PINA code is tested only for function PBC and not for its derivatives. """ + def __init__(self, input_dimension, periods, output_dimension=None): """ :param int input_dimension: The dimension of the input tensor, it can @@ -82,16 +83,18 @@ class PeriodicBoundaryEmbedding(torch.nn.Module): # checks on the periods if isinstance(periods, dict): - if not all(isinstance(dim, (str, int)) and - isinstance(period, (float, int)) - for dim, period in periods.items()): - raise TypeError('In dictionary periods, keys must be integers' - ' or strings, and values must be float or int.') + if not all( + isinstance(dim, (str, int)) and isinstance(period, (float, int)) + for dim, period in periods.items() + ): + raise TypeError( + "In dictionary periods, keys must be integers" + " or strings, and values must be float or int." + ) self._period = periods else: self._period = {k: periods for k in range(input_dimension)} - def forward(self, x): """ Forward pass to compute the periodic boundary conditions embedding. @@ -100,14 +103,24 @@ class PeriodicBoundaryEmbedding(torch.nn.Module): :return: Fourier embeddings of the input. :rtype: torch.Tensor """ - omega = torch.stack([torch.pi * 2. / torch.tensor([val], - device=x.device) - for val in self._period.values()], - dim=-1) + omega = torch.stack( + [ + torch.pi * 2.0 / torch.tensor([val], device=x.device) + for val in self._period.values() + ], + dim=-1, + ) x = self._get_vars(x, list(self._period.keys())) - return self._layer(torch.cat([torch.ones_like(x), - torch.cos(omega * x), - torch.sin(omega * x)], dim=-1)) + return self._layer( + torch.cat( + [ + torch.ones_like(x), + torch.cos(omega * x), + torch.sin(omega * x), + ], + dim=-1, + ) + ) def _get_vars(self, x, indeces): """ @@ -123,16 +136,18 @@ class PeriodicBoundaryEmbedding(torch.nn.Module): return x.extract(indeces) except AttributeError: raise RuntimeError( - 'Not possible to extract input variables from tensor.' - ' Ensure that the passed tensor is a LabelTensor or' - ' pass list of integers to extract variables. For' - ' more information refer to warning in the documentation.') + "Not possible to extract input variables from tensor." + " Ensure that the passed tensor is a LabelTensor or" + " pass list of integers to extract variables. For" + " more information refer to warning in the documentation." + ) elif isinstance(indeces[0], int): return x[..., indeces] else: raise RuntimeError( - 'Not able to extract right indeces for tensor.' - ' For more information refer to warning in the documentation.') + "Not able to extract right indeces for tensor." + " For more information refer to warning in the documentation." + ) @property def period(self):