🎨 Format Python code with psf/black
This commit is contained in:
@@ -9,7 +9,7 @@ __all__ = [
|
||||
"FourierBlock2D",
|
||||
"FourierBlock3D",
|
||||
"PODBlock",
|
||||
"PeriodicBoundaryEmbedding"
|
||||
"PeriodicBoundaryEmbedding",
|
||||
]
|
||||
|
||||
from .convolution_2d import ContinuousConvBlock
|
||||
|
||||
@@ -7,8 +7,8 @@ from pina.utils import check_consistency
|
||||
class PeriodicBoundaryEmbedding(torch.nn.Module):
|
||||
r"""
|
||||
Imposing hard constraint periodic boundary conditions by embedding the
|
||||
input.
|
||||
|
||||
input.
|
||||
|
||||
A periodic function :math:`u:\mathbb{R}^{\rm{in}}
|
||||
\rightarrow\mathbb{R}^{\rm{out}}` periodic in the spatial
|
||||
coordinates :math:`\mathbf{x}` with periods :math:`\mathbf{L}` is such that:
|
||||
@@ -30,7 +30,7 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
|
||||
where :math:`\text{dim}(\tilde{\mathbf{x}}) = 3\text{dim}(\mathbf{x})`.
|
||||
|
||||
.. seealso::
|
||||
**Original reference**:
|
||||
**Original reference**:
|
||||
1. Dong, Suchuan, and Naxian Ni (2021). *A method for representing
|
||||
periodic functions and enforcing exactly periodic boundary
|
||||
conditions with deep neural networks*. Journal of Computational
|
||||
@@ -52,6 +52,7 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
|
||||
:math:`>3`. The PINA code is tested only for function PBC and not for
|
||||
its derivatives.
|
||||
"""
|
||||
|
||||
def __init__(self, input_dimension, periods, output_dimension=None):
|
||||
"""
|
||||
:param int input_dimension: The dimension of the input tensor, it can
|
||||
@@ -82,16 +83,18 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
|
||||
|
||||
# checks on the periods
|
||||
if isinstance(periods, dict):
|
||||
if not all(isinstance(dim, (str, int)) and
|
||||
isinstance(period, (float, int))
|
||||
for dim, period in periods.items()):
|
||||
raise TypeError('In dictionary periods, keys must be integers'
|
||||
' or strings, and values must be float or int.')
|
||||
if not all(
|
||||
isinstance(dim, (str, int)) and isinstance(period, (float, int))
|
||||
for dim, period in periods.items()
|
||||
):
|
||||
raise TypeError(
|
||||
"In dictionary periods, keys must be integers"
|
||||
" or strings, and values must be float or int."
|
||||
)
|
||||
self._period = periods
|
||||
else:
|
||||
self._period = {k: periods for k in range(input_dimension)}
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass to compute the periodic boundary conditions embedding.
|
||||
@@ -100,14 +103,24 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
|
||||
:return: Fourier embeddings of the input.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
omega = torch.stack([torch.pi * 2. / torch.tensor([val],
|
||||
device=x.device)
|
||||
for val in self._period.values()],
|
||||
dim=-1)
|
||||
omega = torch.stack(
|
||||
[
|
||||
torch.pi * 2.0 / torch.tensor([val], device=x.device)
|
||||
for val in self._period.values()
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
x = self._get_vars(x, list(self._period.keys()))
|
||||
return self._layer(torch.cat([torch.ones_like(x),
|
||||
torch.cos(omega * x),
|
||||
torch.sin(omega * x)], dim=-1))
|
||||
return self._layer(
|
||||
torch.cat(
|
||||
[
|
||||
torch.ones_like(x),
|
||||
torch.cos(omega * x),
|
||||
torch.sin(omega * x),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
)
|
||||
|
||||
def _get_vars(self, x, indeces):
|
||||
"""
|
||||
@@ -123,16 +136,18 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
|
||||
return x.extract(indeces)
|
||||
except AttributeError:
|
||||
raise RuntimeError(
|
||||
'Not possible to extract input variables from tensor.'
|
||||
' Ensure that the passed tensor is a LabelTensor or'
|
||||
' pass list of integers to extract variables. For'
|
||||
' more information refer to warning in the documentation.')
|
||||
"Not possible to extract input variables from tensor."
|
||||
" Ensure that the passed tensor is a LabelTensor or"
|
||||
" pass list of integers to extract variables. For"
|
||||
" more information refer to warning in the documentation."
|
||||
)
|
||||
elif isinstance(indeces[0], int):
|
||||
return x[..., indeces]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
'Not able to extract right indeces for tensor.'
|
||||
' For more information refer to warning in the documentation.')
|
||||
"Not able to extract right indeces for tensor."
|
||||
" For more information refer to warning in the documentation."
|
||||
)
|
||||
|
||||
@property
|
||||
def period(self):
|
||||
|
||||
Reference in New Issue
Block a user