Format Python code with psf/black push (#273)

* 🎨 Format Python code with psf/black

---------

Co-authored-by: ndem0 <ndem0@users.noreply.github.com>
Co-authored-by: Nicola Demo <demo.nicola@gmail.com>
Author: github-actions[bot]
Date: 2024-04-02 10:17:45 +02:00
Committed by: GitHub
Parent: cddb191fe4
Commit: 1d1d767317

2 changed files with 21 additions and 20 deletions


@@ -11,7 +11,7 @@ __all__ = [
"PODBlock",
"PeriodicBoundaryEmbedding",
"AVNOBlock",
"AdaptiveActivationFunction"
"AdaptiveActivationFunction",
]
from .convolution_2d import ContinuousConvBlock


@@ -40,7 +40,7 @@ class AdaptiveActivationFunction(torch.nn.Module):
         Parameter containing:
         tensor(1., requires_grad=True)
         >>>
 
     .. seealso::
         **Original reference**: Godfrey, Luke B., and Michael S. Gashler.
@@ -50,7 +50,7 @@ class AdaptiveActivationFunction(torch.nn.Module):
         knowledge engineering and knowledge management (IC3K).
         Vol. 1. IEEE, 2015. DOI: `arXiv preprint arXiv:1602.01321.
         <https://arxiv.org/abs/1602.01321>`_.
     """
 
     def __init__(self, func, alpha=None, beta=None, gamma=None, fixed=None):
@@ -77,17 +77,18 @@ class AdaptiveActivationFunction(torch.nn.Module):
         # see if there are fixed variables
         if fixed is not None:
             check_consistency(fixed, str)
-            if not all(key in ['alpha', 'beta', 'gamma'] for key in fixed):
-                raise TypeError("Fixed keys must be in "
-                                "['alpha', 'beta', 'gamma'].")
+            if not all(key in ["alpha", "beta", "gamma"] for key in fixed):
+                raise TypeError(
+                    "Fixed keys must be in [`alpha`, `beta`, `gamma`]."
+                )
 
         # initialize alpha, beta, gamma if they are None
         if alpha is None:
-            alpha = 1.
+            alpha = 1.0
         if beta is None:
-            beta = 1.
+            beta = 1.0
         if gamma is None:
-            gamma = 0.
+            gamma = 0.0
 
         # checking consistency
         check_consistency(alpha, (float, complex))
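A note on the hunk above: the reformatted check still rejects any fixed key outside alpha, beta and gamma, and omitted arguments fall back to the defaults alpha=1.0, beta=1.0, gamma=0.0. A minimal sketch of how the check surfaces to a caller, assuming the import path below (it is not visible in this diff):

    import torch
    from pina.model.layers import AdaptiveActivationFunction  # assumed import path

    # 'delta' is not an admissible key, so the validation shown above raises.
    try:
        AdaptiveActivationFunction(torch.nn.Tanh(), fixed=["delta"])
    except TypeError as err:
        print(err)  # Fixed keys must be in [`alpha`, `beta`, `gamma`].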
@@ -104,20 +105,20 @@ class AdaptiveActivationFunction(torch.nn.Module):
         # setting not fixed variables as torch.nn.Parameter with gradient
         # registering the buffer for the one which are fixed, buffers by
         # default are saved alongside trainable parameters
-        if 'alpha' not in (fixed or []):
+        if "alpha" not in (fixed or []):
             self._alpha = torch.nn.Parameter(alpha, requires_grad=True)
         else:
-            self.register_buffer('alpha', alpha)
-        if 'beta' not in (fixed or []):
+            self.register_buffer("alpha", alpha)
+        if "beta" not in (fixed or []):
             self._beta = torch.nn.Parameter(beta, requires_grad=True)
         else:
-            self.register_buffer('beta', beta)
-        if 'gamma' not in (fixed or []):
+            self.register_buffer("beta", beta)
+        if "gamma" not in (fixed or []):
             self._gamma = torch.nn.Parameter(gamma, requires_grad=True)
         else:
-            self.register_buffer('gamma', gamma)
+            self.register_buffer("gamma", gamma)
 
         # registering function
         self._func = func
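The hunk above is only a quote-style change, but the logic it touches is worth noting: names listed in fixed are stored with register_buffer (serialized in the state_dict, no gradient), while the others become trainable torch.nn.Parameter objects. A small self-contained sketch of that distinction in plain PyTorch, not PINA code:

    import torch

    class Demo(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # trainable: listed by parameters() and updated by the optimizer
            self._alpha = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True)
            # fixed: no gradient, but still saved alongside the parameters
            self.register_buffer("beta", torch.tensor(1.0))

    demo = Demo()
    print([name for name, _ in demo.named_parameters()])  # ['_alpha']
    print(sorted(demo.state_dict().keys()))               # ['_alpha', 'beta']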
@@ -128,21 +129,21 @@ class AdaptiveActivationFunction(torch.nn.Module):
         Applies the function to the input elementwise.
         """
         return self.alpha * (self._func(self.beta * x + self.gamma))
 
     @property
     def alpha(self):
         """
         The alpha variable
         """
         return self._alpha
 
     @property
     def beta(self):
         """
         The alpha variable
         """
         return self._beta
 
     @property
     def gamma(self):
         """