🎨 Format Python code with psf/black

This commit is contained in:
ndem0
2024-04-08 15:18:17 +00:00
committed by Nicola Demo
parent 4f5d9559b2
commit e2d000ab22
4 changed files with 72 additions and 47 deletions

View File

@@ -1,21 +1,31 @@
__all__ = [ __all__ = [
'AdaptiveActivationFunctionInterface', "AdaptiveActivationFunctionInterface",
'AdaptiveReLU', "AdaptiveReLU",
'AdaptiveSigmoid', "AdaptiveSigmoid",
'AdaptiveTanh', "AdaptiveTanh",
'AdaptiveSiLU', "AdaptiveSiLU",
'AdaptiveMish', "AdaptiveMish",
'AdaptiveELU', "AdaptiveELU",
'AdaptiveCELU', "AdaptiveCELU",
'AdaptiveGELU', "AdaptiveGELU",
'AdaptiveSoftmin', "AdaptiveSoftmin",
'AdaptiveSoftmax', "AdaptiveSoftmax",
'AdaptiveSIREN', "AdaptiveSIREN",
'AdaptiveExp'] "AdaptiveExp",
]
from .adaptive_func import (AdaptiveReLU, AdaptiveSigmoid, AdaptiveTanh, from .adaptive_func import (
AdaptiveSiLU, AdaptiveMish, AdaptiveELU, AdaptiveReLU,
AdaptiveCELU, AdaptiveGELU, AdaptiveSoftmin, AdaptiveSigmoid,
AdaptiveSoftmax, AdaptiveSIREN, AdaptiveExp) AdaptiveTanh,
AdaptiveSiLU,
AdaptiveMish,
AdaptiveELU,
AdaptiveCELU,
AdaptiveGELU,
AdaptiveSoftmin,
AdaptiveSoftmax,
AdaptiveSIREN,
AdaptiveExp,
)
from .adaptive_func_interface import AdaptiveActivationFunctionInterface from .adaptive_func_interface import AdaptiveActivationFunctionInterface

View File

@@ -40,6 +40,7 @@ class AdaptiveReLU(AdaptiveActivationFunctionInterface):
DOI: `JCP 10.1016 DOI: `JCP 10.1016
<https://doi.org/10.1016/j.jcp.2019.109136>`_. <https://doi.org/10.1016/j.jcp.2019.109136>`_.
""" """
def __init__(self, alpha=None, beta=None, gamma=None, fixed=None): def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
super().__init__(alpha, beta, gamma, fixed) super().__init__(alpha, beta, gamma, fixed)
self._func = torch.nn.ReLU() self._func = torch.nn.ReLU()
@@ -80,6 +81,7 @@ class AdaptiveSigmoid(AdaptiveActivationFunctionInterface):
DOI: `JCP 10.1016 DOI: `JCP 10.1016
<https://doi.org/10.1016/j.jcp.2019.109136>`_. <https://doi.org/10.1016/j.jcp.2019.109136>`_.
""" """
def __init__(self, alpha=None, beta=None, gamma=None, fixed=None): def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
super().__init__(alpha, beta, gamma, fixed) super().__init__(alpha, beta, gamma, fixed)
self._func = torch.nn.Sigmoid() self._func = torch.nn.Sigmoid()
@@ -120,6 +122,7 @@ class AdaptiveTanh(AdaptiveActivationFunctionInterface):
DOI: `JCP 10.1016 DOI: `JCP 10.1016
<https://doi.org/10.1016/j.jcp.2019.109136>`_. <https://doi.org/10.1016/j.jcp.2019.109136>`_.
""" """
def __init__(self, alpha=None, beta=None, gamma=None, fixed=None): def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
super().__init__(alpha, beta, gamma, fixed) super().__init__(alpha, beta, gamma, fixed)
self._func = torch.nn.Tanh() self._func = torch.nn.Tanh()
@@ -161,6 +164,7 @@ class AdaptiveSiLU(AdaptiveActivationFunctionInterface):
DOI: `JCP 10.1016 DOI: `JCP 10.1016
<https://doi.org/10.1016/j.jcp.2019.109136>`_. <https://doi.org/10.1016/j.jcp.2019.109136>`_.
""" """
def __init__(self, alpha=None, beta=None, gamma=None, fixed=None): def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
super().__init__(alpha, beta, gamma, fixed) super().__init__(alpha, beta, gamma, fixed)
self._func = torch.nn.SiLU() self._func = torch.nn.SiLU()
@@ -201,6 +205,7 @@ class AdaptiveMish(AdaptiveActivationFunctionInterface):
DOI: `JCP 10.1016 DOI: `JCP 10.1016
<https://doi.org/10.1016/j.jcp.2019.109136>`_. <https://doi.org/10.1016/j.jcp.2019.109136>`_.
""" """
def __init__(self, alpha=None, beta=None, gamma=None, fixed=None): def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
super().__init__(alpha, beta, gamma, fixed) super().__init__(alpha, beta, gamma, fixed)
self._func = torch.nn.Mish() self._func = torch.nn.Mish()
@@ -244,6 +249,7 @@ class AdaptiveELU(AdaptiveActivationFunctionInterface):
DOI: `JCP 10.1016 DOI: `JCP 10.1016
<https://doi.org/10.1016/j.jcp.2019.109136>`_. <https://doi.org/10.1016/j.jcp.2019.109136>`_.
""" """
def __init__(self, alpha=None, beta=None, gamma=None, fixed=None): def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
super().__init__(alpha, beta, gamma, fixed) super().__init__(alpha, beta, gamma, fixed)
self._func = torch.nn.ELU() self._func = torch.nn.ELU()
@@ -284,10 +290,12 @@ class AdaptiveCELU(AdaptiveActivationFunctionInterface):
DOI: `JCP 10.1016 DOI: `JCP 10.1016
<https://doi.org/10.1016/j.jcp.2019.109136>`_. <https://doi.org/10.1016/j.jcp.2019.109136>`_.
""" """
def __init__(self, alpha=None, beta=None, gamma=None, fixed=None): def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
super().__init__(alpha, beta, gamma, fixed) super().__init__(alpha, beta, gamma, fixed)
self._func = torch.nn.CELU() self._func = torch.nn.CELU()
class AdaptiveGELU(AdaptiveActivationFunctionInterface): class AdaptiveGELU(AdaptiveActivationFunctionInterface):
r""" r"""
Adaptive trainable :class:`~torch.nn.GELU` activation function. Adaptive trainable :class:`~torch.nn.GELU` activation function.
@@ -324,6 +332,7 @@ class AdaptiveGELU(AdaptiveActivationFunctionInterface):
DOI: `JCP 10.1016 DOI: `JCP 10.1016
<https://doi.org/10.1016/j.jcp.2019.109136>`_. <https://doi.org/10.1016/j.jcp.2019.109136>`_.
""" """
def __init__(self, alpha=None, beta=None, gamma=None, fixed=None): def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
super().__init__(alpha, beta, gamma, fixed) super().__init__(alpha, beta, gamma, fixed)
self._func = torch.nn.GELU() self._func = torch.nn.GELU()
@@ -364,6 +373,7 @@ class AdaptiveSoftmin(AdaptiveActivationFunctionInterface):
DOI: `JCP 10.1016 DOI: `JCP 10.1016
<https://doi.org/10.1016/j.jcp.2019.109136>`_. <https://doi.org/10.1016/j.jcp.2019.109136>`_.
""" """
def __init__(self, alpha=None, beta=None, gamma=None, fixed=None): def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
super().__init__(alpha, beta, gamma, fixed) super().__init__(alpha, beta, gamma, fixed)
self._func = torch.nn.Softmin() self._func = torch.nn.Softmin()
@@ -404,10 +414,12 @@ class AdaptiveSoftmax(AdaptiveActivationFunctionInterface):
DOI: `JCP 10.1016 DOI: `JCP 10.1016
<https://doi.org/10.1016/j.jcp.2019.109136>`_. <https://doi.org/10.1016/j.jcp.2019.109136>`_.
""" """
def __init__(self, alpha=None, beta=None, gamma=None, fixed=None): def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
super().__init__(alpha, beta, gamma, fixed) super().__init__(alpha, beta, gamma, fixed)
self._func = torch.nn.Softmax() self._func = torch.nn.Softmax()
class AdaptiveSIREN(AdaptiveActivationFunctionInterface): class AdaptiveSIREN(AdaptiveActivationFunctionInterface):
r""" r"""
Adaptive trainable :obj:`~torch.sin` function. Adaptive trainable :obj:`~torch.sin` function.
@@ -439,10 +451,12 @@ class AdaptiveSIREN(AdaptiveActivationFunctionInterface):
DOI: `JCP 10.1016 DOI: `JCP 10.1016
<https://doi.org/10.1016/j.jcp.2019.109136>`_. <https://doi.org/10.1016/j.jcp.2019.109136>`_.
""" """
def __init__(self, alpha=None, beta=None, gamma=None, fixed=None): def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
super().__init__(alpha, beta, gamma, fixed) super().__init__(alpha, beta, gamma, fixed)
self._func = torch.sin self._func = torch.sin
class AdaptiveExp(AdaptiveActivationFunctionInterface): class AdaptiveExp(AdaptiveActivationFunctionInterface):
r""" r"""
Adaptive trainable :obj:`~torch.exp` function. Adaptive trainable :obj:`~torch.exp` function.
@@ -474,15 +488,16 @@ class AdaptiveExp(AdaptiveActivationFunctionInterface):
DOI: `JCP 10.1016 DOI: `JCP 10.1016
<https://doi.org/10.1016/j.jcp.2019.109136>`_. <https://doi.org/10.1016/j.jcp.2019.109136>`_.
""" """
def __init__(self, alpha=None, beta=None, fixed=None): def __init__(self, alpha=None, beta=None, fixed=None):
# only alpha, and beta parameters (gamma=0 fixed) # only alpha, and beta parameters (gamma=0 fixed)
if fixed is None: if fixed is None:
fixed = ['gamma'] fixed = ["gamma"]
else: else:
check_consistency(fixed, str) check_consistency(fixed, str)
fixed = list(fixed) + ['gamma'] fixed = list(fixed) + ["gamma"]
# calling super # calling super
super().__init__(alpha, beta, 0., fixed) super().__init__(alpha, beta, 0.0, fixed)
self._func = torch.exp self._func = torch.exp