🎨 Format Python code with psf/black

This commit is contained in:
ndem0
2024-03-01 17:16:04 +00:00
committed by Nicola Demo
parent 4cfd90b904
commit 22ceee1755
2 changed files with 38 additions and 23 deletions

View File

@@ -9,7 +9,7 @@ __all__ = [
"FourierBlock2D", "FourierBlock2D",
"FourierBlock3D", "FourierBlock3D",
"PODBlock", "PODBlock",
"PeriodicBoundaryEmbedding" "PeriodicBoundaryEmbedding",
] ]
from .convolution_2d import ContinuousConvBlock from .convolution_2d import ContinuousConvBlock

View File

@@ -52,6 +52,7 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
:math:`>3`. The PINA code is tested only for function PBC and not for :math:`>3`. The PINA code is tested only for function PBC and not for
its derivatives. its derivatives.
""" """
def __init__(self, input_dimension, periods, output_dimension=None): def __init__(self, input_dimension, periods, output_dimension=None):
""" """
:param int input_dimension: The dimension of the input tensor, it can :param int input_dimension: The dimension of the input tensor, it can
@@ -82,16 +83,18 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
# checks on the periods # checks on the periods
if isinstance(periods, dict): if isinstance(periods, dict):
if not all(isinstance(dim, (str, int)) and if not all(
isinstance(period, (float, int)) isinstance(dim, (str, int)) and isinstance(period, (float, int))
for dim, period in periods.items()): for dim, period in periods.items()
raise TypeError('In dictionary periods, keys must be integers' ):
' or strings, and values must be float or int.') raise TypeError(
"In dictionary periods, keys must be integers"
" or strings, and values must be float or int."
)
self._period = periods self._period = periods
else: else:
self._period = {k: periods for k in range(input_dimension)} self._period = {k: periods for k in range(input_dimension)}
def forward(self, x): def forward(self, x):
""" """
Forward pass to compute the periodic boundary conditions embedding. Forward pass to compute the periodic boundary conditions embedding.
@@ -100,14 +103,24 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
:return: Fourier embeddings of the input. :return: Fourier embeddings of the input.
:rtype: torch.Tensor :rtype: torch.Tensor
""" """
omega = torch.stack([torch.pi * 2. / torch.tensor([val], omega = torch.stack(
device=x.device) [
for val in self._period.values()], torch.pi * 2.0 / torch.tensor([val], device=x.device)
dim=-1) for val in self._period.values()
],
dim=-1,
)
x = self._get_vars(x, list(self._period.keys())) x = self._get_vars(x, list(self._period.keys()))
return self._layer(torch.cat([torch.ones_like(x), return self._layer(
torch.cat(
[
torch.ones_like(x),
torch.cos(omega * x), torch.cos(omega * x),
torch.sin(omega * x)], dim=-1)) torch.sin(omega * x),
],
dim=-1,
)
)
def _get_vars(self, x, indeces): def _get_vars(self, x, indeces):
""" """
@@ -123,16 +136,18 @@ class PeriodicBoundaryEmbedding(torch.nn.Module):
return x.extract(indeces) return x.extract(indeces)
except AttributeError: except AttributeError:
raise RuntimeError( raise RuntimeError(
'Not possible to extract input variables from tensor.' "Not possible to extract input variables from tensor."
' Ensure that the passed tensor is a LabelTensor or' " Ensure that the passed tensor is a LabelTensor or"
' pass list of integers to extract variables. For' " pass list of integers to extract variables. For"
' more information refer to warning in the documentation.') " more information refer to warning in the documentation."
)
elif isinstance(indeces[0], int): elif isinstance(indeces[0], int):
return x[..., indeces] return x[..., indeces]
else: else:
raise RuntimeError( raise RuntimeError(
'Not able to extract right indeces for tensor.' "Not able to extract right indeces for tensor."
' For more information refer to warning in the documentation.') " For more information refer to warning in the documentation."
)
@property @property
def period(self): def period(self):