Format Python code with psf/black push (#273)

* 🎨 Format Python code with psf/black

---------

Co-authored-by: ndem0 <ndem0@users.noreply.github.com>
Co-authored-by: Nicola Demo <demo.nicola@gmail.com>
github-actions[bot]
2024-04-02 10:17:45 +02:00
committed by GitHub
parent cddb191fe4
commit 1d1d767317
2 changed files with 21 additions and 20 deletions


@@ -11,7 +11,7 @@ __all__ = [
     "PODBlock",
     "PeriodicBoundaryEmbedding",
     "AVNOBlock",
-    "AdaptiveActivationFunction"
+    "AdaptiveActivationFunction",
 ]
 from .convolution_2d import ContinuousConvBlock

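Aside, not part of the commit: the hunk above is black's "magic trailing comma" rule, and the hunks in the second file below show its quote and float-literal normalization. A minimal illustration (invented snippet, not from this repo) of the same rules, reproducible locally with `black .`:

    # before black
    items = [
        'a',
        'b'
    ]
    x = 1.

    # after running `black .`: black keeps the exploded list and appends a
    # trailing comma, normalizes quotes to double quotes, and rewrites the
    # bare float literal `1.` as `1.0`
    items = [
        "a",
        "b",
    ]
    x = 1.0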

@@ -77,17 +77,18 @@ class AdaptiveActivationFunction(torch.nn.Module):
         # see if there are fixed variables
         if fixed is not None:
             check_consistency(fixed, str)
-            if not all(key in ['alpha', 'beta', 'gamma'] for key in fixed):
-                raise TypeError("Fixed keys must be in "
-                                "['alpha', 'beta', 'gamma'].")
+            if not all(key in ["alpha", "beta", "gamma"] for key in fixed):
+                raise TypeError(
+                    "Fixed keys must be in [`alpha`, `beta`, `gamma`]."
+                )
 
         # initialize alpha, beta, gamma if they are None
         if alpha is None:
-            alpha = 1.
+            alpha = 1.0
         if beta is None:
-            beta = 1.
+            beta = 1.0
         if gamma is None:
-            gamma = 0.
+            gamma = 0.0
 
         # checking consistency
         check_consistency(alpha, (float, complex))
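Aside, not part of the commit: a minimal sketch of how the validation in the hunk above is expected to behave. The import path is assumed from the `__all__` hunk in the first file, and the constructor signature (activation first, keyword arguments after) is an assumption:

    import torch
    from pina.model.layers import AdaptiveActivationFunction  # assumed path

    # an unknown key in `fixed` should trip the TypeError raised above
    try:
        AdaptiveActivationFunction(torch.nn.Tanh(), fixed=["delta"])
    except TypeError as err:
        print(err)  # expected: Fixed keys must be in [`alpha`, `beta`, `gamma`].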
@@ -104,20 +105,20 @@ class AdaptiveActivationFunction(torch.nn.Module):
         # setting not fixed variables as torch.nn.Parameter with gradient
         # registering the buffer for the one which are fixed, buffers by
         # default are saved alongside trainable parameters
-        if 'alpha' not in (fixed or []):
+        if "alpha" not in (fixed or []):
             self._alpha = torch.nn.Parameter(alpha, requires_grad=True)
         else:
-            self.register_buffer('alpha', alpha)
+            self.register_buffer("alpha", alpha)
 
-        if 'beta' not in (fixed or []):
+        if "beta" not in (fixed or []):
             self._beta = torch.nn.Parameter(beta, requires_grad=True)
         else:
-            self.register_buffer('beta', beta)
+            self.register_buffer("beta", beta)
 
-        if 'gamma' not in (fixed or []):
+        if "gamma" not in (fixed or []):
             self._gamma = torch.nn.Parameter(gamma, requires_grad=True)
         else:
-            self.register_buffer('gamma', gamma)
+            self.register_buffer("gamma", gamma)
 
         # registering function
         self._func = func
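Aside, not part of the commit: the hunk above is what makes entries listed in `fixed` non-trainable. A sketch of the observable effect, with the import path and attribute names assumed from the diff:

    import torch
    from pina.model.layers import AdaptiveActivationFunction  # assumed path

    act = AdaptiveActivationFunction(torch.nn.Tanh(), fixed=["gamma"])

    # alpha and beta are registered as torch.nn.Parameter, so an optimizer
    # will update them; the fixed gamma becomes a buffer: saved in
    # state_dict() but excluded from act.parameters()
    print([name for name, _ in act.named_parameters()])  # expected: ['_alpha', '_beta']
    print([name for name, _ in act.named_buffers()])     # expected: ['gamma']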