🎨 Format Python code with psf/black
This commit is contained in:
@@ -1,21 +1,31 @@
|
||||
__all__ = [
|
||||
'AdaptiveActivationFunctionInterface',
|
||||
'AdaptiveReLU',
|
||||
'AdaptiveSigmoid',
|
||||
'AdaptiveTanh',
|
||||
'AdaptiveSiLU',
|
||||
'AdaptiveMish',
|
||||
'AdaptiveELU',
|
||||
'AdaptiveCELU',
|
||||
'AdaptiveGELU',
|
||||
'AdaptiveSoftmin',
|
||||
'AdaptiveSoftmax',
|
||||
'AdaptiveSIREN',
|
||||
'AdaptiveExp']
|
||||
"AdaptiveActivationFunctionInterface",
|
||||
"AdaptiveReLU",
|
||||
"AdaptiveSigmoid",
|
||||
"AdaptiveTanh",
|
||||
"AdaptiveSiLU",
|
||||
"AdaptiveMish",
|
||||
"AdaptiveELU",
|
||||
"AdaptiveCELU",
|
||||
"AdaptiveGELU",
|
||||
"AdaptiveSoftmin",
|
||||
"AdaptiveSoftmax",
|
||||
"AdaptiveSIREN",
|
||||
"AdaptiveExp",
|
||||
]
|
||||
|
||||
from .adaptive_func import (AdaptiveReLU, AdaptiveSigmoid, AdaptiveTanh,
|
||||
AdaptiveSiLU, AdaptiveMish, AdaptiveELU,
|
||||
AdaptiveCELU, AdaptiveGELU, AdaptiveSoftmin,
|
||||
AdaptiveSoftmax, AdaptiveSIREN, AdaptiveExp)
|
||||
from .adaptive_func import (
|
||||
AdaptiveReLU,
|
||||
AdaptiveSigmoid,
|
||||
AdaptiveTanh,
|
||||
AdaptiveSiLU,
|
||||
AdaptiveMish,
|
||||
AdaptiveELU,
|
||||
AdaptiveCELU,
|
||||
AdaptiveGELU,
|
||||
AdaptiveSoftmin,
|
||||
AdaptiveSoftmax,
|
||||
AdaptiveSIREN,
|
||||
AdaptiveExp,
|
||||
)
|
||||
from .adaptive_func_interface import AdaptiveActivationFunctionInterface
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ class AdaptiveReLU(AdaptiveActivationFunctionInterface):
|
||||
DOI: `JCP 10.1016
|
||||
<https://doi.org/10.1016/j.jcp.2019.109136>`_.
|
||||
"""
|
||||
|
||||
def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
|
||||
super().__init__(alpha, beta, gamma, fixed)
|
||||
self._func = torch.nn.ReLU()
|
||||
@@ -80,6 +81,7 @@ class AdaptiveSigmoid(AdaptiveActivationFunctionInterface):
|
||||
DOI: `JCP 10.1016
|
||||
<https://doi.org/10.1016/j.jcp.2019.109136>`_.
|
||||
"""
|
||||
|
||||
def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
|
||||
super().__init__(alpha, beta, gamma, fixed)
|
||||
self._func = torch.nn.Sigmoid()
|
||||
@@ -120,6 +122,7 @@ class AdaptiveTanh(AdaptiveActivationFunctionInterface):
|
||||
DOI: `JCP 10.1016
|
||||
<https://doi.org/10.1016/j.jcp.2019.109136>`_.
|
||||
"""
|
||||
|
||||
def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
|
||||
super().__init__(alpha, beta, gamma, fixed)
|
||||
self._func = torch.nn.Tanh()
|
||||
@@ -161,6 +164,7 @@ class AdaptiveSiLU(AdaptiveActivationFunctionInterface):
|
||||
DOI: `JCP 10.1016
|
||||
<https://doi.org/10.1016/j.jcp.2019.109136>`_.
|
||||
"""
|
||||
|
||||
def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
|
||||
super().__init__(alpha, beta, gamma, fixed)
|
||||
self._func = torch.nn.SiLU()
|
||||
@@ -201,6 +205,7 @@ class AdaptiveMish(AdaptiveActivationFunctionInterface):
|
||||
DOI: `JCP 10.1016
|
||||
<https://doi.org/10.1016/j.jcp.2019.109136>`_.
|
||||
"""
|
||||
|
||||
def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
|
||||
super().__init__(alpha, beta, gamma, fixed)
|
||||
self._func = torch.nn.Mish()
|
||||
@@ -244,6 +249,7 @@ class AdaptiveELU(AdaptiveActivationFunctionInterface):
|
||||
DOI: `JCP 10.1016
|
||||
<https://doi.org/10.1016/j.jcp.2019.109136>`_.
|
||||
"""
|
||||
|
||||
def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
|
||||
super().__init__(alpha, beta, gamma, fixed)
|
||||
self._func = torch.nn.ELU()
|
||||
@@ -284,10 +290,12 @@ class AdaptiveCELU(AdaptiveActivationFunctionInterface):
|
||||
DOI: `JCP 10.1016
|
||||
<https://doi.org/10.1016/j.jcp.2019.109136>`_.
|
||||
"""
|
||||
|
||||
def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
|
||||
super().__init__(alpha, beta, gamma, fixed)
|
||||
self._func = torch.nn.CELU()
|
||||
|
||||
|
||||
class AdaptiveGELU(AdaptiveActivationFunctionInterface):
|
||||
r"""
|
||||
Adaptive trainable :class:`~torch.nn.GELU` activation function.
|
||||
@@ -324,6 +332,7 @@ class AdaptiveGELU(AdaptiveActivationFunctionInterface):
|
||||
DOI: `JCP 10.1016
|
||||
<https://doi.org/10.1016/j.jcp.2019.109136>`_.
|
||||
"""
|
||||
|
||||
def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
|
||||
super().__init__(alpha, beta, gamma, fixed)
|
||||
self._func = torch.nn.GELU()
|
||||
@@ -364,6 +373,7 @@ class AdaptiveSoftmin(AdaptiveActivationFunctionInterface):
|
||||
DOI: `JCP 10.1016
|
||||
<https://doi.org/10.1016/j.jcp.2019.109136>`_.
|
||||
"""
|
||||
|
||||
def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
|
||||
super().__init__(alpha, beta, gamma, fixed)
|
||||
self._func = torch.nn.Softmin()
|
||||
@@ -404,10 +414,12 @@ class AdaptiveSoftmax(AdaptiveActivationFunctionInterface):
|
||||
DOI: `JCP 10.1016
|
||||
<https://doi.org/10.1016/j.jcp.2019.109136>`_.
|
||||
"""
|
||||
|
||||
def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
|
||||
super().__init__(alpha, beta, gamma, fixed)
|
||||
self._func = torch.nn.Softmax()
|
||||
|
||||
|
||||
class AdaptiveSIREN(AdaptiveActivationFunctionInterface):
|
||||
r"""
|
||||
Adaptive trainable :obj:`~torch.sin` function.
|
||||
@@ -439,10 +451,12 @@ class AdaptiveSIREN(AdaptiveActivationFunctionInterface):
|
||||
DOI: `JCP 10.1016
|
||||
<https://doi.org/10.1016/j.jcp.2019.109136>`_.
|
||||
"""
|
||||
|
||||
def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
|
||||
super().__init__(alpha, beta, gamma, fixed)
|
||||
self._func = torch.sin
|
||||
|
||||
|
||||
class AdaptiveExp(AdaptiveActivationFunctionInterface):
|
||||
r"""
|
||||
Adaptive trainable :obj:`~torch.exp` function.
|
||||
@@ -474,15 +488,16 @@ class AdaptiveExp(AdaptiveActivationFunctionInterface):
|
||||
DOI: `JCP 10.1016
|
||||
<https://doi.org/10.1016/j.jcp.2019.109136>`_.
|
||||
"""
|
||||
|
||||
def __init__(self, alpha=None, beta=None, fixed=None):
|
||||
|
||||
# only alpha, and beta parameters (gamma=0 fixed)
|
||||
if fixed is None:
|
||||
fixed = ['gamma']
|
||||
fixed = ["gamma"]
|
||||
else:
|
||||
check_consistency(fixed, str)
|
||||
fixed = list(fixed) + ['gamma']
|
||||
fixed = list(fixed) + ["gamma"]
|
||||
|
||||
# calling super
|
||||
super().__init__(alpha, beta, 0., fixed)
|
||||
super().__init__(alpha, beta, 0.0, fixed)
|
||||
self._func = torch.exp
|
||||
Reference in New Issue
Block a user