🎨 Format Python code with psf/black
This commit is contained in:
@@ -9,7 +9,7 @@ __all__ = [
|
|||||||
"FourierBlock2D",
|
"FourierBlock2D",
|
||||||
"FourierBlock3D",
|
"FourierBlock3D",
|
||||||
"PODBlock",
|
"PODBlock",
|
||||||
"PeriodicBoundaryEmbedding"
|
"PeriodicBoundaryEmbedding",
|
||||||
]
|
]
|
||||||
|
|
||||||
from .convolution_2d import ContinuousConvBlock
|
from .convolution_2d import ContinuousConvBlock
|
||||||
|
|||||||
@@ -7,8 +7,8 @@ from pina.utils import check_consistency
|
|||||||
class PeriodicBoundaryEmbedding(torch.nn.Module):
|
class PeriodicBoundaryEmbedding(torch.nn.Module):
|
||||||
r"""
|
r"""
|
||||||
Imposing hard constraint periodic boundary conditions by embedding the
|
Imposing hard constraint periodic boundary conditions by embedding the
|
||||||
input.
|
input.
|
||||||
|
|
||||||
A periodic function :math:`u:\mathbb{R}^{\rm{in}}
|
A periodic function :math:`u:\mathbb{R}^{\rm{in}}
|
||||||
\rightarrow\mathbb{R}^{\rm{out}}` periodic in the spatial
|
\rightarrow\mathbb{R}^{\rm{out}}` periodic in the spatial
|
||||||
coordinates :math:`\mathbf{x}` with periods :math:`\mathbf{L}` is such that:
|
coordinates :math:`\mathbf{x}` with periods :math:`\mathbf{L}` is such that:
|
||||||
@@ -30,7 +30,7 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
|
|||||||
where :math:`\text{dim}(\tilde{\mathbf{x}}) = 3\text{dim}(\mathbf{x})`.
|
where :math:`\text{dim}(\tilde{\mathbf{x}}) = 3\text{dim}(\mathbf{x})`.
|
||||||
|
|
||||||
.. seealso::
|
.. seealso::
|
||||||
**Original reference**:
|
**Original reference**:
|
||||||
1. Dong, Suchuan, and Naxian Ni (2021). *A method for representing
|
1. Dong, Suchuan, and Naxian Ni (2021). *A method for representing
|
||||||
periodic functions and enforcing exactly periodic boundary
|
periodic functions and enforcing exactly periodic boundary
|
||||||
conditions with deep neural networks*. Journal of Computational
|
conditions with deep neural networks*. Journal of Computational
|
||||||
@@ -52,6 +52,7 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
|
|||||||
:math:`>3`. The PINA code is tested only for function PBC and not for
|
:math:`>3`. The PINA code is tested only for function PBC and not for
|
||||||
its derivatives.
|
its derivatives.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, input_dimension, periods, output_dimension=None):
|
def __init__(self, input_dimension, periods, output_dimension=None):
|
||||||
"""
|
"""
|
||||||
:param int input_dimension: The dimension of the input tensor, it can
|
:param int input_dimension: The dimension of the input tensor, it can
|
||||||
@@ -82,16 +83,18 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
|
|||||||
|
|
||||||
# checks on the periods
|
# checks on the periods
|
||||||
if isinstance(periods, dict):
|
if isinstance(periods, dict):
|
||||||
if not all(isinstance(dim, (str, int)) and
|
if not all(
|
||||||
isinstance(period, (float, int))
|
isinstance(dim, (str, int)) and isinstance(period, (float, int))
|
||||||
for dim, period in periods.items()):
|
for dim, period in periods.items()
|
||||||
raise TypeError('In dictionary periods, keys must be integers'
|
):
|
||||||
' or strings, and values must be float or int.')
|
raise TypeError(
|
||||||
|
"In dictionary periods, keys must be integers"
|
||||||
|
" or strings, and values must be float or int."
|
||||||
|
)
|
||||||
self._period = periods
|
self._period = periods
|
||||||
else:
|
else:
|
||||||
self._period = {k: periods for k in range(input_dimension)}
|
self._period = {k: periods for k in range(input_dimension)}
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
"""
|
"""
|
||||||
Forward pass to compute the periodic boundary conditions embedding.
|
Forward pass to compute the periodic boundary conditions embedding.
|
||||||
@@ -100,14 +103,24 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
|
|||||||
:return: Fourier embeddings of the input.
|
:return: Fourier embeddings of the input.
|
||||||
:rtype: torch.Tensor
|
:rtype: torch.Tensor
|
||||||
"""
|
"""
|
||||||
omega = torch.stack([torch.pi * 2. / torch.tensor([val],
|
omega = torch.stack(
|
||||||
device=x.device)
|
[
|
||||||
for val in self._period.values()],
|
torch.pi * 2.0 / torch.tensor([val], device=x.device)
|
||||||
dim=-1)
|
for val in self._period.values()
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
x = self._get_vars(x, list(self._period.keys()))
|
x = self._get_vars(x, list(self._period.keys()))
|
||||||
return self._layer(torch.cat([torch.ones_like(x),
|
return self._layer(
|
||||||
torch.cos(omega * x),
|
torch.cat(
|
||||||
torch.sin(omega * x)], dim=-1))
|
[
|
||||||
|
torch.ones_like(x),
|
||||||
|
torch.cos(omega * x),
|
||||||
|
torch.sin(omega * x),
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def _get_vars(self, x, indeces):
|
def _get_vars(self, x, indeces):
|
||||||
"""
|
"""
|
||||||
@@ -123,16 +136,18 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
|
|||||||
return x.extract(indeces)
|
return x.extract(indeces)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
'Not possible to extract input variables from tensor.'
|
"Not possible to extract input variables from tensor."
|
||||||
' Ensure that the passed tensor is a LabelTensor or'
|
" Ensure that the passed tensor is a LabelTensor or"
|
||||||
' pass list of integers to extract variables. For'
|
" pass list of integers to extract variables. For"
|
||||||
' more information refer to warning in the documentation.')
|
" more information refer to warning in the documentation."
|
||||||
|
)
|
||||||
elif isinstance(indeces[0], int):
|
elif isinstance(indeces[0], int):
|
||||||
return x[..., indeces]
|
return x[..., indeces]
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
'Not able to extract right indeces for tensor.'
|
"Not able to extract right indeces for tensor."
|
||||||
' For more information refer to warning in the documentation.')
|
" For more information refer to warning in the documentation."
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def period(self):
|
def period(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user