🎨 Format Python code with psf/black

This commit is contained in:
ndem0
2024-03-01 17:16:04 +00:00
committed by Nicola Demo
parent 4cfd90b904
commit 22ceee1755
2 changed files with 38 additions and 23 deletions

View File

@@ -9,7 +9,7 @@ __all__ = [
"FourierBlock2D",
"FourierBlock3D",
"PODBlock",
"PeriodicBoundaryEmbedding"
"PeriodicBoundaryEmbedding",
]
from .convolution_2d import ContinuousConvBlock

View File

@@ -7,8 +7,8 @@ from pina.utils import check_consistency
class PeriodicBoundaryEmbedding(torch.nn.Module):
r"""
Imposing hard constraint periodic boundary conditions by embedding the
input.
input.
A periodic function :math:`u:\mathbb{R}^{\rm{in}}
\rightarrow\mathbb{R}^{\rm{out}}` periodic in the spatial
coordinates :math:`\mathbf{x}` with periods :math:`\mathbf{L}` is such that:
@@ -30,7 +30,7 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
where :math:`\text{dim}(\tilde{\mathbf{x}}) = 3\text{dim}(\mathbf{x})`.
.. seealso::
**Original reference**:
**Original reference**:
1. Dong, Suchuan, and Naxian Ni (2021). *A method for representing
periodic functions and enforcing exactly periodic boundary
conditions with deep neural networks*. Journal of Computational
@@ -52,6 +52,7 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
:math:`>3`. The PINA code is tested only for function PBC and not for
its derivatives.
"""
def __init__(self, input_dimension, periods, output_dimension=None):
"""
:param int input_dimension: The dimension of the input tensor, it can
@@ -82,16 +83,18 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
# checks on the periods
if isinstance(periods, dict):
if not all(isinstance(dim, (str, int)) and
isinstance(period, (float, int))
for dim, period in periods.items()):
raise TypeError('In dictionary periods, keys must be integers'
' or strings, and values must be float or int.')
if not all(
isinstance(dim, (str, int)) and isinstance(period, (float, int))
for dim, period in periods.items()
):
raise TypeError(
"In dictionary periods, keys must be integers"
" or strings, and values must be float or int."
)
self._period = periods
else:
self._period = {k: periods for k in range(input_dimension)}
def forward(self, x):
"""
Forward pass to compute the periodic boundary conditions embedding.
@@ -100,14 +103,24 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
:return: Fourier embeddings of the input.
:rtype: torch.Tensor
"""
omega = torch.stack([torch.pi * 2. / torch.tensor([val],
device=x.device)
for val in self._period.values()],
dim=-1)
omega = torch.stack(
[
torch.pi * 2.0 / torch.tensor([val], device=x.device)
for val in self._period.values()
],
dim=-1,
)
x = self._get_vars(x, list(self._period.keys()))
return self._layer(torch.cat([torch.ones_like(x),
torch.cos(omega * x),
torch.sin(omega * x)], dim=-1))
return self._layer(
torch.cat(
[
torch.ones_like(x),
torch.cos(omega * x),
torch.sin(omega * x),
],
dim=-1,
)
)
def _get_vars(self, x, indeces):
"""
@@ -123,16 +136,18 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
return x.extract(indeces)
except AttributeError:
raise RuntimeError(
'Not possible to extract input variables from tensor.'
' Ensure that the passed tensor is a LabelTensor or'
' pass list of integers to extract variables. For'
' more information refer to warning in the documentation.')
"Not possible to extract input variables from tensor."
" Ensure that the passed tensor is a LabelTensor or"
" pass list of integers to extract variables. For"
" more information refer to warning in the documentation."
)
elif isinstance(indeces[0], int):
return x[..., indeces]
else:
raise RuntimeError(
'Not able to extract right indeces for tensor.'
' For more information refer to warning in the documentation.')
"Not able to extract right indeces for tensor."
" For more information refer to warning in the documentation."
)
@property
def period(self):