🎨 Format Python code with psf/black
@@ -1,21 +1,31 @@
 __all__ = [
-    'AdaptiveActivationFunctionInterface',
-    'AdaptiveReLU',
-    'AdaptiveSigmoid',
-    'AdaptiveTanh',
-    'AdaptiveSiLU',
-    'AdaptiveMish',
-    'AdaptiveELU',
-    'AdaptiveCELU',
-    'AdaptiveGELU',
-    'AdaptiveSoftmin',
-    'AdaptiveSoftmax',
-    'AdaptiveSIREN',
-    'AdaptiveExp']
+    "AdaptiveActivationFunctionInterface",
+    "AdaptiveReLU",
+    "AdaptiveSigmoid",
+    "AdaptiveTanh",
+    "AdaptiveSiLU",
+    "AdaptiveMish",
+    "AdaptiveELU",
+    "AdaptiveCELU",
+    "AdaptiveGELU",
+    "AdaptiveSoftmin",
+    "AdaptiveSoftmax",
+    "AdaptiveSIREN",
+    "AdaptiveExp",
+]

-from .adaptive_func import (AdaptiveReLU, AdaptiveSigmoid, AdaptiveTanh,
-                            AdaptiveSiLU, AdaptiveMish, AdaptiveELU,
-                            AdaptiveCELU, AdaptiveGELU, AdaptiveSoftmin,
-                            AdaptiveSoftmax, AdaptiveSIREN, AdaptiveExp)
+from .adaptive_func import (
+    AdaptiveReLU,
+    AdaptiveSigmoid,
+    AdaptiveTanh,
+    AdaptiveSiLU,
+    AdaptiveMish,
+    AdaptiveELU,
+    AdaptiveCELU,
+    AdaptiveGELU,
+    AdaptiveSoftmin,
+    AdaptiveSoftmax,
+    AdaptiveSIREN,
+    AdaptiveExp,
+)

 from .adaptive_func_interface import AdaptiveActivationFunctionInterface

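For readers unfamiliar with Black's conventions, the hunk above shows its two most visible rewrites: string literals are normalized to double quotes, and a bracketed collection that does not fit on one line (88 characters by default) is exploded to one element per line with a trailing comma. That trailing comma is "magic": once present, Black keeps the collection exploded even when it would fit. A minimal sketch of this behavior follows; the `names` variable is purely illustrative and not part of this module.

names = ["AdaptiveReLU", "AdaptiveSigmoid"]  # fits in 88 chars and has no trailing comma -> stays on one line
names = [
    "AdaptiveReLU",
    "AdaptiveSigmoid",
]  # trailing comma present -> Black keeps one element per line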
@@ -40,6 +40,7 @@ class AdaptiveReLU(AdaptiveActivationFunctionInterface):
     DOI: `JCP 10.1016
     <https://doi.org/10.1016/j.jcp.2019.109136>`_.
     """
+
     def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
         super().__init__(alpha, beta, gamma, fixed)
         self._func = torch.nn.ReLU()
@@ -80,6 +81,7 @@ class AdaptiveSigmoid(AdaptiveActivationFunctionInterface):
     DOI: `JCP 10.1016
     <https://doi.org/10.1016/j.jcp.2019.109136>`_.
     """
+
     def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
         super().__init__(alpha, beta, gamma, fixed)
         self._func = torch.nn.Sigmoid()
@@ -120,6 +122,7 @@ class AdaptiveTanh(AdaptiveActivationFunctionInterface):
     DOI: `JCP 10.1016
     <https://doi.org/10.1016/j.jcp.2019.109136>`_.
     """
+
     def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
         super().__init__(alpha, beta, gamma, fixed)
         self._func = torch.nn.Tanh()
@@ -161,6 +164,7 @@ class AdaptiveSiLU(AdaptiveActivationFunctionInterface):
     DOI: `JCP 10.1016
     <https://doi.org/10.1016/j.jcp.2019.109136>`_.
     """
+
     def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
         super().__init__(alpha, beta, gamma, fixed)
         self._func = torch.nn.SiLU()
@@ -201,6 +205,7 @@ class AdaptiveMish(AdaptiveActivationFunctionInterface):
     DOI: `JCP 10.1016
     <https://doi.org/10.1016/j.jcp.2019.109136>`_.
     """
+
     def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
         super().__init__(alpha, beta, gamma, fixed)
         self._func = torch.nn.Mish()
@@ -244,6 +249,7 @@ class AdaptiveELU(AdaptiveActivationFunctionInterface):
     DOI: `JCP 10.1016
     <https://doi.org/10.1016/j.jcp.2019.109136>`_.
     """
+
     def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
         super().__init__(alpha, beta, gamma, fixed)
         self._func = torch.nn.ELU()
@@ -284,10 +290,12 @@ class AdaptiveCELU(AdaptiveActivationFunctionInterface):
     DOI: `JCP 10.1016
     <https://doi.org/10.1016/j.jcp.2019.109136>`_.
     """
+
     def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
         super().__init__(alpha, beta, gamma, fixed)
         self._func = torch.nn.CELU()

+
 class AdaptiveGELU(AdaptiveActivationFunctionInterface):
     r"""
     Adaptive trainable :class:`~torch.nn.GELU` activation function.
@@ -324,6 +332,7 @@ class AdaptiveGELU(AdaptiveActivationFunctionInterface):
     DOI: `JCP 10.1016
     <https://doi.org/10.1016/j.jcp.2019.109136>`_.
     """
+
     def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
         super().__init__(alpha, beta, gamma, fixed)
         self._func = torch.nn.GELU()
@@ -364,6 +373,7 @@ class AdaptiveSoftmin(AdaptiveActivationFunctionInterface):
     DOI: `JCP 10.1016
     <https://doi.org/10.1016/j.jcp.2019.109136>`_.
     """
+
     def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
         super().__init__(alpha, beta, gamma, fixed)
         self._func = torch.nn.Softmin()
@@ -404,10 +414,12 @@ class AdaptiveSoftmax(AdaptiveActivationFunctionInterface):
     DOI: `JCP 10.1016
     <https://doi.org/10.1016/j.jcp.2019.109136>`_.
     """
+
     def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
         super().__init__(alpha, beta, gamma, fixed)
         self._func = torch.nn.Softmax()

+
 class AdaptiveSIREN(AdaptiveActivationFunctionInterface):
     r"""
     Adaptive trainable :obj:`~torch.sin` function.
@@ -439,10 +451,12 @@ class AdaptiveSIREN(AdaptiveActivationFunctionInterface):
     DOI: `JCP 10.1016
     <https://doi.org/10.1016/j.jcp.2019.109136>`_.
     """
+
     def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
         super().__init__(alpha, beta, gamma, fixed)
         self._func = torch.sin

+
 class AdaptiveExp(AdaptiveActivationFunctionInterface):
     r"""
     Adaptive trainable :obj:`~torch.exp` function.
@@ -474,15 +488,16 @@ class AdaptiveExp(AdaptiveActivationFunctionInterface):
     DOI: `JCP 10.1016
     <https://doi.org/10.1016/j.jcp.2019.109136>`_.
     """
+
     def __init__(self, alpha=None, beta=None, fixed=None):

         # only alpha, and beta parameters (gamma=0 fixed)
         if fixed is None:
-            fixed = ['gamma']
+            fixed = ["gamma"]
         else:
             check_consistency(fixed, str)
-            fixed = list(fixed) + ['gamma']
+            fixed = list(fixed) + ["gamma"]

         # calling super
-        super().__init__(alpha, beta, 0., fixed)
+        super().__init__(alpha, beta, 0.0, fixed)
         self._func = torch.exp
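Beyond the quote normalization, the last hunk shows Black padding the float literal 0. to 0.0. As context for the code being reformatted: each Adaptive* class wires a fixed PyTorch nonlinearity (self._func) to trainable alpha, beta, gamma parameters handled by AdaptiveActivationFunctionInterface, with AdaptiveExp pinning gamma via fixed=["gamma"]. The sketch below is a minimal, self-contained illustration of that adaptive-activation idea from the cited JCP paper; the ToyAdaptiveTanh class and the functional form alpha * tanh(beta * x + gamma) are assumptions for illustration only, not the implementation in this repository.

import torch


class ToyAdaptiveTanh(torch.nn.Module):
    """Toy adaptive activation: alpha * tanh(beta * x + gamma) with trainable scalars."""

    def __init__(self, alpha=1.0, beta=1.0, gamma=0.0):
        super().__init__()
        # each scalar is a learnable parameter, mirroring the alpha/beta/gamma in the diff
        self.alpha = torch.nn.Parameter(torch.tensor(alpha))
        self.beta = torch.nn.Parameter(torch.tensor(beta))
        self.gamma = torch.nn.Parameter(torch.tensor(gamma))

    def forward(self, x):
        # trainable output scale, input slope, and input shift around a fixed nonlinearity
        return self.alpha * torch.tanh(self.beta * x + self.gamma)


x = torch.linspace(-2.0, 2.0, steps=5)
print(ToyAdaptiveTanh()(x))  # five activated values, differentiable w.r.t. alpha/beta/gamma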