Format Python code with psf/black (#273)

* 🎨 Format Python code with psf/black

Co-authored-by: ndem0 <ndem0@users.noreply.github.com>
Co-authored-by: Nicola Demo <demo.nicola@gmail.com>
parent cddb191fe4
commit 1d1d767317
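
The changes below are mechanical: black normalizes string quotes to double quotes, pads bare float literals, splits long calls, and adds trailing commas to multi-line literals. A minimal sketch of the same normalizations through black's Python API (black.format_str and black.Mode are the library's public entry points; the input strings are illustrative, not the actual file contents):

import black

mode = black.Mode()  # default style, as used by the autoblack workflow

# Single quotes are normalized to double quotes:
print(black.format_str("x = 'alpha'", mode=mode), end="")  # x = "alpha"

# Bare float literals gain the missing digit:
print(black.format_str("alpha = 1.", mode=mode), end="")   # alpha = 1.0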
@@ -11,7 +11,7 @@ __all__ = [
     "PODBlock",
     "PeriodicBoundaryEmbedding",
     "AVNOBlock",
-    "AdaptiveActivationFunction"
+    "AdaptiveActivationFunction",
 ]
 
 from .convolution_2d import ContinuousConvBlock
@@ -40,7 +40,7 @@ class AdaptiveActivationFunction(torch.nn.Module):
         Parameter containing:
         tensor(1., requires_grad=True)
         >>>
 
     .. seealso::
 
         **Original reference**: Godfrey, Luke B., and Michael S. Gashler.
@@ -50,7 +50,7 @@ class AdaptiveActivationFunction(torch.nn.Module):
         knowledge engineering and knowledge management (IC3K).
         Vol. 1. IEEE, 2015. DOI: `arXiv preprint arXiv:1602.01321.
         <https://arxiv.org/abs/1602.01321>`_.
 
     """
 
     def __init__(self, func, alpha=None, beta=None, gamma=None, fixed=None):
@@ -77,17 +77,18 @@ class AdaptiveActivationFunction(torch.nn.Module):
         # see if there are fixed variables
         if fixed is not None:
             check_consistency(fixed, str)
-            if not all(key in ['alpha', 'beta', 'gamma'] for key in fixed):
-                raise TypeError("Fixed keys must be in "
-                                "['alpha', 'beta', 'gamma'].")
+            if not all(key in ["alpha", "beta", "gamma"] for key in fixed):
+                raise TypeError(
+                    "Fixed keys must be in [`alpha`, `beta`, `gamma`]."
+                )
 
         # initialize alpha, beta, gamma if they are None
         if alpha is None:
-            alpha = 1.
+            alpha = 1.0
         if beta is None:
-            beta = 1.
+            beta = 1.0
         if gamma is None:
-            gamma = 0.
+            gamma = 0.0
 
         # checking consistency
         check_consistency(alpha, (float, complex))
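
A side note on the `fixed` handling kept by this hunk: each coefficient ends up as either a trainable torch.nn.Parameter or a fixed buffer, and `(fixed or [])` makes the membership test safe when fixed is None. A minimal sketch of the distinction, using a hypothetical toy module (Scale is not part of this diff):

import torch

class Scale(torch.nn.Module):
    # Hypothetical toy module illustrating the pattern, not code from the diff.
    def __init__(self, fixed=None):
        super().__init__()
        if "alpha" not in (fixed or []):  # `fixed or []` tolerates fixed=None
            self._alpha = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True)
        else:
            self.register_buffer("alpha", torch.tensor(1.0))
        self._beta = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True)

m = Scale(fixed=["alpha"])
print([n for n, _ in m.named_parameters()])  # only '_beta' is trainable
print(sorted(m.state_dict()))                # both 'alpha' and '_beta' are saved

Buffers are saved alongside parameters in state_dict(), which is exactly what the comment in the next hunk notes: fixed coefficients still travel with the model.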
@@ -104,20 +105,20 @@ class AdaptiveActivationFunction(torch.nn.Module):
         # setting not fixed variables as torch.nn.Parameter with gradient
         # registering the buffer for the one which are fixed, buffers by
         # default are saved alongside trainable parameters
-        if 'alpha' not in (fixed or []):
+        if "alpha" not in (fixed or []):
             self._alpha = torch.nn.Parameter(alpha, requires_grad=True)
         else:
-            self.register_buffer('alpha', alpha)
+            self.register_buffer("alpha", alpha)
 
-        if 'beta' not in (fixed or []):
+        if "beta" not in (fixed or []):
             self._beta = torch.nn.Parameter(beta, requires_grad=True)
         else:
-            self.register_buffer('beta', beta)
+            self.register_buffer("beta", beta)
 
-        if 'gamma' not in (fixed or []):
+        if "gamma" not in (fixed or []):
             self._gamma = torch.nn.Parameter(gamma, requires_grad=True)
         else:
-            self.register_buffer('gamma', gamma)
+            self.register_buffer("gamma", gamma)
 
         # registering function
         self._func = func
@@ -128,21 +129,21 @@ class AdaptiveActivationFunction(torch.nn.Module):
         Applies the function to the input elementwise.
         """
         return self.alpha * (self._func(self.beta * x + self.gamma))
 
     @property
     def alpha(self):
         """
         The alpha variable
         """
         return self._alpha
 
     @property
     def beta(self):
         """
         The beta variable
         """
         return self._beta
 
     @property
     def gamma(self):
         """
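
Taken together, the class computes alpha * func(beta * x + gamma) elementwise, with alpha, beta, and gamma trainable unless listed in fixed. A usage sketch, assuming the class is exported from this repository's layers module (the import path below is an assumption inferred from the __all__ hunk above):

import torch
from pina.model.layers import AdaptiveActivationFunction  # assumed import path

# Learnable rescaling of tanh: y = alpha * tanh(beta * x + gamma)
act = AdaptiveActivationFunction(torch.nn.Tanh(), alpha=2.0, beta=1.0, gamma=0.0)
y = act(torch.linspace(-1.0, 1.0, 5))
print(y)          # elementwise 2.0 * tanh(x)
print(act.alpha)  # a trainable torch.nn.Parameter, since 'alpha' is not fixed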