Renaming
* solvers -> solver
* adaptive_functions -> adaptive_function
* callbacks -> callback
* operators -> operator
* pinns -> physics_informed_solver
* layers -> block
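For downstream code the change is purely in the import paths. A minimal before/after sketch in Python; the `adaptive_function` path is taken from the new `pina/adaptive_function/__init__.py` added below, while the commented paths for the other renamed modules are assumptions inferred from the rename list above and are not shown in this diff:

# before this commit (old module name)
# from pina.adaptive_functions import AdaptiveReLU

# after this commit (new module name)
from pina.adaptive_function import AdaptiveReLU

# analogously (hypothetical, inferred from the rename list):
# pina.solvers -> pina.solver, pina.callbacks -> pina.callback,
# pina.operators -> pina.operator, pina.model.layers -> pina.model.block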
committed by Nicola Demo
parent 810d215ca0
commit df673cad4e
pina/adaptive_function/__init__.py (new file, 31 lines added)
@@ -0,0 +1,31 @@
__all__ = [
    "AdaptiveActivationFunctionInterface",
    "AdaptiveReLU",
    "AdaptiveSigmoid",
    "AdaptiveTanh",
    "AdaptiveSiLU",
    "AdaptiveMish",
    "AdaptiveELU",
    "AdaptiveCELU",
    "AdaptiveGELU",
    "AdaptiveSoftmin",
    "AdaptiveSoftmax",
    "AdaptiveSIREN",
    "AdaptiveExp",
]

from .adaptive_func import (
    AdaptiveReLU,
    AdaptiveSigmoid,
    AdaptiveTanh,
    AdaptiveSiLU,
    AdaptiveMish,
    AdaptiveELU,
    AdaptiveCELU,
    AdaptiveGELU,
    AdaptiveSoftmin,
    AdaptiveSoftmax,
    AdaptiveSIREN,
    AdaptiveExp,
)
from .adaptive_func_interface import AdaptiveActivationFunctionInterface

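As a quick sanity check of the new public API, a minimal usage sketch (assuming a PINA installation that includes this commit; tensor values are arbitrary):

import torch
from pina.adaptive_function import AdaptiveTanh

act = AdaptiveTanh()                         # alpha = beta = 1, gamma = 0 by default
y = act(torch.linspace(-2.0, 2.0, 5))        # computes alpha * tanh(beta * x + gamma)
print(y)
print([p.shape for p in act.parameters()])   # three scalar trainable parameters
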
pina/adaptive_function/adaptive_func.py (new file, 503 lines added)
@@ -0,0 +1,503 @@
""" Module for adaptive functions. """

import torch
from ..utils import check_consistency
from .adaptive_func_interface import AdaptiveActivationFunctionInterface


class AdaptiveReLU(AdaptiveActivationFunctionInterface):
    r"""
    Adaptive trainable :class:`~torch.nn.ReLU` activation function.

    Given the function :math:`\text{ReLU}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
    the adaptive function
    :math:`\text{ReLU}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n`
    is defined as:

    .. math::
        \text{ReLU}_{\text{adaptive}}({x}) = \alpha\,\text{ReLU}(\beta{x}+\gamma),

    where :math:`\alpha,\,\beta,\,\gamma` are trainable parameters, and the
    ReLU function is defined as:

    .. math::
        \text{ReLU}(x) = \max(0, x)

    .. seealso::

        **Original reference**: Godfrey, Luke B., and Michael S. Gashler.
        *A continuum among logarithmic, linear, and exponential functions,
        and its potential to improve generalization in neural networks.*
        2015 7th international joint conference on knowledge discovery,
        knowledge engineering and knowledge management (IC3K).
        Vol. 1. IEEE, 2015. DOI: `arXiv preprint arXiv:1602.01321.
        <https://arxiv.org/abs/1602.01321>`_.

        Jagtap, Ameya D., Kenji Kawaguchi, and George Em Karniadakis. *Adaptive
        activation functions accelerate convergence in deep and
        physics-informed neural networks*. Journal of
        Computational Physics 404 (2020): 109136.
        DOI: `JCP 10.1016
        <https://doi.org/10.1016/j.jcp.2019.109136>`_.
    """

    def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
        super().__init__(alpha, beta, gamma, fixed)
        self._func = torch.nn.ReLU()


class AdaptiveSigmoid(AdaptiveActivationFunctionInterface):
    r"""
    Adaptive trainable :class:`~torch.nn.Sigmoid` activation function.

    Given the function :math:`\text{Sigmoid}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
    the adaptive function
    :math:`\text{Sigmoid}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n`
    is defined as:

    .. math::
        \text{Sigmoid}_{\text{adaptive}}({x}) = \alpha\,\text{Sigmoid}(\beta{x}+\gamma),

    where :math:`\alpha,\,\beta,\,\gamma` are trainable parameters, and the
    Sigmoid function is defined as:

    .. math::
        \text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)}

    .. seealso::

        **Original reference**: Godfrey, Luke B., and Michael S. Gashler.
        *A continuum among logarithmic, linear, and exponential functions,
        and its potential to improve generalization in neural networks.*
        2015 7th international joint conference on knowledge discovery,
        knowledge engineering and knowledge management (IC3K).
        Vol. 1. IEEE, 2015. DOI: `arXiv preprint arXiv:1602.01321.
        <https://arxiv.org/abs/1602.01321>`_.

        Jagtap, Ameya D., Kenji Kawaguchi, and George Em Karniadakis. *Adaptive
        activation functions accelerate convergence in deep and
        physics-informed neural networks*. Journal of
        Computational Physics 404 (2020): 109136.
        DOI: `JCP 10.1016
        <https://doi.org/10.1016/j.jcp.2019.109136>`_.
    """

    def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
        super().__init__(alpha, beta, gamma, fixed)
        self._func = torch.nn.Sigmoid()


class AdaptiveTanh(AdaptiveActivationFunctionInterface):
    r"""
    Adaptive trainable :class:`~torch.nn.Tanh` activation function.

    Given the function :math:`\text{Tanh}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
    the adaptive function
    :math:`\text{Tanh}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n`
    is defined as:

    .. math::
        \text{Tanh}_{\text{adaptive}}({x}) = \alpha\,\text{Tanh}(\beta{x}+\gamma),

    where :math:`\alpha,\,\beta,\,\gamma` are trainable parameters, and the
    Tanh function is defined as:

    .. math::
        \text{Tanh}(x) = \frac{\exp(x) - \exp(-x)}{\exp(x) + \exp(-x)}

    .. seealso::

        **Original reference**: Godfrey, Luke B., and Michael S. Gashler.
        *A continuum among logarithmic, linear, and exponential functions,
        and its potential to improve generalization in neural networks.*
        2015 7th international joint conference on knowledge discovery,
        knowledge engineering and knowledge management (IC3K).
        Vol. 1. IEEE, 2015. DOI: `arXiv preprint arXiv:1602.01321.
        <https://arxiv.org/abs/1602.01321>`_.

        Jagtap, Ameya D., Kenji Kawaguchi, and George Em Karniadakis. *Adaptive
        activation functions accelerate convergence in deep and
        physics-informed neural networks*. Journal of
        Computational Physics 404 (2020): 109136.
        DOI: `JCP 10.1016
        <https://doi.org/10.1016/j.jcp.2019.109136>`_.
    """

    def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
        super().__init__(alpha, beta, gamma, fixed)
        self._func = torch.nn.Tanh()


class AdaptiveSiLU(AdaptiveActivationFunctionInterface):
    r"""
    Adaptive trainable :class:`~torch.nn.SiLU` activation function.

    Given the function :math:`\text{SiLU}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
    the adaptive function
    :math:`\text{SiLU}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n`
    is defined as:

    .. math::
        \text{SiLU}_{\text{adaptive}}({x}) = \alpha\,\text{SiLU}(\beta{x}+\gamma),

    where :math:`\alpha,\,\beta,\,\gamma` are trainable parameters, and the
    SiLU function is defined as:

    .. math::
        \text{SiLU}(x) = x * \sigma(x), \text{where }\sigma(x)
        \text{ is the logistic sigmoid.}

    .. seealso::

        **Original reference**: Godfrey, Luke B., and Michael S. Gashler.
        *A continuum among logarithmic, linear, and exponential functions,
        and its potential to improve generalization in neural networks.*
        2015 7th international joint conference on knowledge discovery,
        knowledge engineering and knowledge management (IC3K).
        Vol. 1. IEEE, 2015. DOI: `arXiv preprint arXiv:1602.01321.
        <https://arxiv.org/abs/1602.01321>`_.

        Jagtap, Ameya D., Kenji Kawaguchi, and George Em Karniadakis. *Adaptive
        activation functions accelerate convergence in deep and
        physics-informed neural networks*. Journal of
        Computational Physics 404 (2020): 109136.
        DOI: `JCP 10.1016
        <https://doi.org/10.1016/j.jcp.2019.109136>`_.
    """

    def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
        super().__init__(alpha, beta, gamma, fixed)
        self._func = torch.nn.SiLU()


class AdaptiveMish(AdaptiveActivationFunctionInterface):
    r"""
    Adaptive trainable :class:`~torch.nn.Mish` activation function.

    Given the function :math:`\text{Mish}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
    the adaptive function
    :math:`\text{Mish}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n`
    is defined as:

    .. math::
        \text{Mish}_{\text{adaptive}}({x}) = \alpha\,\text{Mish}(\beta{x}+\gamma),

    where :math:`\alpha,\,\beta,\,\gamma` are trainable parameters, and the
    Mish function is defined as:

    .. math::
        \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))

    .. seealso::

        **Original reference**: Godfrey, Luke B., and Michael S. Gashler.
        *A continuum among logarithmic, linear, and exponential functions,
        and its potential to improve generalization in neural networks.*
        2015 7th international joint conference on knowledge discovery,
        knowledge engineering and knowledge management (IC3K).
        Vol. 1. IEEE, 2015. DOI: `arXiv preprint arXiv:1602.01321.
        <https://arxiv.org/abs/1602.01321>`_.

        Jagtap, Ameya D., Kenji Kawaguchi, and George Em Karniadakis. *Adaptive
        activation functions accelerate convergence in deep and
        physics-informed neural networks*. Journal of
        Computational Physics 404 (2020): 109136.
        DOI: `JCP 10.1016
        <https://doi.org/10.1016/j.jcp.2019.109136>`_.
    """

    def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
        super().__init__(alpha, beta, gamma, fixed)
        self._func = torch.nn.Mish()


class AdaptiveELU(AdaptiveActivationFunctionInterface):
    r"""
    Adaptive trainable :class:`~torch.nn.ELU` activation function.

    Given the function :math:`\text{ELU}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
    the adaptive function
    :math:`\text{ELU}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n`
    is defined as:

    .. math::
        \text{ELU}_{\text{adaptive}}({x}) = \alpha\,\text{ELU}(\beta{x}+\gamma),

    where :math:`\alpha,\,\beta,\,\gamma` are trainable parameters, and the
    ELU function is defined as:

    .. math::
        \text{ELU}(x) = \begin{cases}
        x, & \text{ if }x > 0\\
        \exp(x) - 1, & \text{ if }x \leq 0
        \end{cases}

    .. seealso::

        **Original reference**: Godfrey, Luke B., and Michael S. Gashler.
        *A continuum among logarithmic, linear, and exponential functions,
        and its potential to improve generalization in neural networks.*
        2015 7th international joint conference on knowledge discovery,
        knowledge engineering and knowledge management (IC3K).
        Vol. 1. IEEE, 2015. DOI: `arXiv preprint arXiv:1602.01321.
        <https://arxiv.org/abs/1602.01321>`_.

        Jagtap, Ameya D., Kenji Kawaguchi, and George Em Karniadakis. *Adaptive
        activation functions accelerate convergence in deep and
        physics-informed neural networks*. Journal of
        Computational Physics 404 (2020): 109136.
        DOI: `JCP 10.1016
        <https://doi.org/10.1016/j.jcp.2019.109136>`_.
    """

    def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
        super().__init__(alpha, beta, gamma, fixed)
        self._func = torch.nn.ELU()


class AdaptiveCELU(AdaptiveActivationFunctionInterface):
    r"""
    Adaptive trainable :class:`~torch.nn.CELU` activation function.

    Given the function :math:`\text{CELU}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
    the adaptive function
    :math:`\text{CELU}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n`
    is defined as:

    .. math::
        \text{CELU}_{\text{adaptive}}({x}) = \alpha\,\text{CELU}(\beta{x}+\gamma),

    where :math:`\alpha,\,\beta,\,\gamma` are trainable parameters, and the
    CELU function is defined as:

    .. math::
        \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x) - 1))

    .. seealso::

        **Original reference**: Godfrey, Luke B., and Michael S. Gashler.
        *A continuum among logarithmic, linear, and exponential functions,
        and its potential to improve generalization in neural networks.*
        2015 7th international joint conference on knowledge discovery,
        knowledge engineering and knowledge management (IC3K).
        Vol. 1. IEEE, 2015. DOI: `arXiv preprint arXiv:1602.01321.
        <https://arxiv.org/abs/1602.01321>`_.

        Jagtap, Ameya D., Kenji Kawaguchi, and George Em Karniadakis. *Adaptive
        activation functions accelerate convergence in deep and
        physics-informed neural networks*. Journal of
        Computational Physics 404 (2020): 109136.
        DOI: `JCP 10.1016
        <https://doi.org/10.1016/j.jcp.2019.109136>`_.
    """

    def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
        super().__init__(alpha, beta, gamma, fixed)
        self._func = torch.nn.CELU()


class AdaptiveGELU(AdaptiveActivationFunctionInterface):
    r"""
    Adaptive trainable :class:`~torch.nn.GELU` activation function.

    Given the function :math:`\text{GELU}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
    the adaptive function
    :math:`\text{GELU}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n`
    is defined as:

    .. math::
        \text{GELU}_{\text{adaptive}}({x}) = \alpha\,\text{GELU}(\beta{x}+\gamma),

    where :math:`\alpha,\,\beta,\,\gamma` are trainable parameters, and the
    GELU function is defined as:

    .. math::
        \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3)))

    .. seealso::

        **Original reference**: Godfrey, Luke B., and Michael S. Gashler.
        *A continuum among logarithmic, linear, and exponential functions,
        and its potential to improve generalization in neural networks.*
        2015 7th international joint conference on knowledge discovery,
        knowledge engineering and knowledge management (IC3K).
        Vol. 1. IEEE, 2015. DOI: `arXiv preprint arXiv:1602.01321.
        <https://arxiv.org/abs/1602.01321>`_.

        Jagtap, Ameya D., Kenji Kawaguchi, and George Em Karniadakis. *Adaptive
        activation functions accelerate convergence in deep and
        physics-informed neural networks*. Journal of
        Computational Physics 404 (2020): 109136.
        DOI: `JCP 10.1016
        <https://doi.org/10.1016/j.jcp.2019.109136>`_.
    """

    def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
        super().__init__(alpha, beta, gamma, fixed)
        self._func = torch.nn.GELU()


class AdaptiveSoftmin(AdaptiveActivationFunctionInterface):
    r"""
    Adaptive trainable :class:`~torch.nn.Softmin` activation function.

    Given the function :math:`\text{Softmin}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
    the adaptive function
    :math:`\text{Softmin}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n`
    is defined as:

    .. math::
        \text{Softmin}_{\text{adaptive}}({x}) = \alpha\,\text{Softmin}(\beta{x}+\gamma),

    where :math:`\alpha,\,\beta,\,\gamma` are trainable parameters, and the
    Softmin function is defined as:

    .. math::
        \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}

    .. seealso::

        **Original reference**: Godfrey, Luke B., and Michael S. Gashler.
        *A continuum among logarithmic, linear, and exponential functions,
        and its potential to improve generalization in neural networks.*
        2015 7th international joint conference on knowledge discovery,
        knowledge engineering and knowledge management (IC3K).
        Vol. 1. IEEE, 2015. DOI: `arXiv preprint arXiv:1602.01321.
        <https://arxiv.org/abs/1602.01321>`_.

        Jagtap, Ameya D., Kenji Kawaguchi, and George Em Karniadakis. *Adaptive
        activation functions accelerate convergence in deep and
        physics-informed neural networks*. Journal of
        Computational Physics 404 (2020): 109136.
        DOI: `JCP 10.1016
        <https://doi.org/10.1016/j.jcp.2019.109136>`_.
    """

    def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
        super().__init__(alpha, beta, gamma, fixed)
        self._func = torch.nn.Softmin()


class AdaptiveSoftmax(AdaptiveActivationFunctionInterface):
    r"""
    Adaptive trainable :class:`~torch.nn.Softmax` activation function.

    Given the function :math:`\text{Softmax}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
    the adaptive function
    :math:`\text{Softmax}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n`
    is defined as:

    .. math::
        \text{Softmax}_{\text{adaptive}}({x}) = \alpha\,\text{Softmax}(\beta{x}+\gamma),

    where :math:`\alpha,\,\beta,\,\gamma` are trainable parameters, and the
    Softmax function is defined as:

    .. math::
        \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

    .. seealso::

        **Original reference**: Godfrey, Luke B., and Michael S. Gashler.
        *A continuum among logarithmic, linear, and exponential functions,
        and its potential to improve generalization in neural networks.*
        2015 7th international joint conference on knowledge discovery,
        knowledge engineering and knowledge management (IC3K).
        Vol. 1. IEEE, 2015. DOI: `arXiv preprint arXiv:1602.01321.
        <https://arxiv.org/abs/1602.01321>`_.

        Jagtap, Ameya D., Kenji Kawaguchi, and George Em Karniadakis. *Adaptive
        activation functions accelerate convergence in deep and
        physics-informed neural networks*. Journal of
        Computational Physics 404 (2020): 109136.
        DOI: `JCP 10.1016
        <https://doi.org/10.1016/j.jcp.2019.109136>`_.
    """

    def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
        super().__init__(alpha, beta, gamma, fixed)
        self._func = torch.nn.Softmax()


class AdaptiveSIREN(AdaptiveActivationFunctionInterface):
    r"""
    Adaptive trainable :obj:`~torch.sin` function.

    Given the function :math:`\text{sin}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
    the adaptive function
    :math:`\text{sin}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n`
    is defined as:

    .. math::
        \text{sin}_{\text{adaptive}}({x}) = \alpha\,\text{sin}(\beta{x}+\gamma),

    where :math:`\alpha,\,\beta,\,\gamma` are trainable parameters.

    .. seealso::

        **Original reference**: Godfrey, Luke B., and Michael S. Gashler.
        *A continuum among logarithmic, linear, and exponential functions,
        and its potential to improve generalization in neural networks.*
        2015 7th international joint conference on knowledge discovery,
        knowledge engineering and knowledge management (IC3K).
        Vol. 1. IEEE, 2015. DOI: `arXiv preprint arXiv:1602.01321.
        <https://arxiv.org/abs/1602.01321>`_.

        Jagtap, Ameya D., Kenji Kawaguchi, and George Em Karniadakis. *Adaptive
        activation functions accelerate convergence in deep and
        physics-informed neural networks*. Journal of
        Computational Physics 404 (2020): 109136.
        DOI: `JCP 10.1016
        <https://doi.org/10.1016/j.jcp.2019.109136>`_.
    """

    def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
        super().__init__(alpha, beta, gamma, fixed)
        self._func = torch.sin


class AdaptiveExp(AdaptiveActivationFunctionInterface):
    r"""
    Adaptive trainable :obj:`~torch.exp` function.

    Given the function :math:`\text{exp}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
    the adaptive function
    :math:`\text{exp}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n`
    is defined as:

    .. math::
        \text{exp}_{\text{adaptive}}({x}) = \alpha\,\text{exp}(\beta{x}),

    where :math:`\alpha,\,\beta` are trainable parameters.

    .. seealso::

        **Original reference**: Godfrey, Luke B., and Michael S. Gashler.
        *A continuum among logarithmic, linear, and exponential functions,
        and its potential to improve generalization in neural networks.*
        2015 7th international joint conference on knowledge discovery,
        knowledge engineering and knowledge management (IC3K).
        Vol. 1. IEEE, 2015. DOI: `arXiv preprint arXiv:1602.01321.
        <https://arxiv.org/abs/1602.01321>`_.

        Jagtap, Ameya D., Kenji Kawaguchi, and George Em Karniadakis. *Adaptive
        activation functions accelerate convergence in deep and
        physics-informed neural networks*. Journal of
        Computational Physics 404 (2020): 109136.
        DOI: `JCP 10.1016
        <https://doi.org/10.1016/j.jcp.2019.109136>`_.
    """

    def __init__(self, alpha=None, beta=None, fixed=None):

        # only alpha and beta are trainable (gamma = 0 is always fixed)
        if fixed is None:
            fixed = ["gamma"]
        else:
            check_consistency(fixed, str)
            fixed = list(fixed) + ["gamma"]

        # calling super
        super().__init__(alpha, beta, 0.0, fixed)
        self._func = torch.exp

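A short sketch of the ``fixed`` argument, which freezes selected parameters so they are not optimized (the counts follow from the class definitions above; ``AdaptiveExp`` always fixes ``gamma`` to 0); assuming a PINA installation that includes this commit:

import torch
from pina.adaptive_function import AdaptiveExp, AdaptiveSigmoid

exp_act = AdaptiveExp(alpha=2.0, beta=0.5)          # gamma is fixed to 0 internally
sig_act = AdaptiveSigmoid(fixed=["alpha", "beta"])  # only gamma stays trainable

# fixed parameters are registered as buffers, so they do not appear here
print(len(list(exp_act.parameters())))  # 2 (alpha, beta)
print(len(list(sig_act.parameters())))  # 1 (gamma)
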
pina/adaptive_function/adaptive_func_interface.py (new file, 146 lines added)
@@ -0,0 +1,146 @@
""" Module for adaptive functions. """

import torch

from pina.utils import check_consistency
from abc import ABCMeta


class AdaptiveActivationFunctionInterface(torch.nn.Module, metaclass=ABCMeta):
    r"""
    The
    :class:`~pina.adaptive_function.adaptive_func_interface.AdaptiveActivationFunctionInterface`
    class turns a :class:`torch.nn.Module` activation function into an
    adaptive, trainable activation function. To create a new adaptive
    activation function, this class must be used as the base class.

    Given a function :math:`f:\mathbb{R}^n\rightarrow\mathbb{R}^m`, the adaptive
    function :math:`f_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^m`
    is defined as:

    .. math::
        f_{\text{adaptive}}(\mathbf{x}) = \alpha\,f(\beta\mathbf{x}+\gamma),

    where :math:`\alpha,\,\beta,\,\gamma` are trainable parameters.

    .. seealso::

        **Original reference**: Godfrey, Luke B., and Michael S. Gashler.
        *A continuum among logarithmic, linear, and exponential functions,
        and its potential to improve generalization in neural networks.*
        2015 7th international joint conference on knowledge discovery,
        knowledge engineering and knowledge management (IC3K).
        Vol. 1. IEEE, 2015. DOI: `arXiv preprint arXiv:1602.01321.
        <https://arxiv.org/abs/1602.01321>`_.

        Jagtap, Ameya D., Kenji Kawaguchi, and George Em Karniadakis. *Adaptive
        activation functions accelerate convergence in deep and
        physics-informed neural networks*. Journal of
        Computational Physics 404 (2020): 109136.
        DOI: `JCP 10.1016
        <https://doi.org/10.1016/j.jcp.2019.109136>`_.
    """

    def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
        """
        Initializes the Adaptive Function.

        :param float | complex alpha: Scaling parameter alpha.
            Defaults to ``None``. When ``None`` is passed,
            the variable is initialized to 1.
        :param float | complex beta: Scaling parameter beta.
            Defaults to ``None``. When ``None`` is passed,
            the variable is initialized to 1.
        :param float | complex gamma: Shifting parameter gamma.
            Defaults to ``None``. When ``None`` is passed,
            the variable is initialized to 0.
        :param list fixed: List of parameters to fix during training,
            i.e. not optimized (``requires_grad`` set to ``False``).
            Options are ``alpha``, ``beta``, ``gamma``. Defaults to ``None``.
        """
        super().__init__()

        # see if there are fixed variables
        if fixed is not None:
            check_consistency(fixed, str)
            if not all(key in ["alpha", "beta", "gamma"] for key in fixed):
                raise TypeError(
                    "Fixed keys must be in [`alpha`, `beta`, `gamma`]."
                )

        # initialize alpha, beta, gamma if they are None
        if alpha is None:
            alpha = 1.0
        if beta is None:
            beta = 1.0
        if gamma is None:
            gamma = 0.0

        # checking consistency
        check_consistency(alpha, (float, complex))
        check_consistency(beta, (float, complex))
        check_consistency(gamma, (float, complex))

        # registering as tensors
        alpha = torch.tensor(alpha, requires_grad=False)
        beta = torch.tensor(beta, requires_grad=False)
        gamma = torch.tensor(gamma, requires_grad=False)

        # setting non-fixed variables as torch.nn.Parameter with gradient, and
        # registering a buffer for the fixed ones; buffers are saved alongside
        # trainable parameters by default. Buffers are registered under the
        # private names (e.g. "_alpha") so that the `alpha`, `beta` and
        # `gamma` properties below resolve in both cases.
        if "alpha" not in (fixed or []):
            self._alpha = torch.nn.Parameter(alpha, requires_grad=True)
        else:
            self.register_buffer("_alpha", alpha)

        if "beta" not in (fixed or []):
            self._beta = torch.nn.Parameter(beta, requires_grad=True)
        else:
            self.register_buffer("_beta", beta)

        if "gamma" not in (fixed or []):
            self._gamma = torch.nn.Parameter(gamma, requires_grad=True)
        else:
            self.register_buffer("_gamma", gamma)

        # storing the activation
        self._func = None

    def forward(self, x):
        """
        Define the computation performed at every call.
        The function is applied to the input elementwise.

        :param x: The input tensor on which the activation function is
            evaluated.
        :type x: torch.Tensor | LabelTensor
        """
        return self.alpha * (self._func(self.beta * x + self.gamma))

    @property
    def alpha(self):
        """
        The alpha variable.
        """
        return self._alpha

    @property
    def beta(self):
        """
        The beta variable.
        """
        return self._beta

    @property
    def gamma(self):
        """
        The gamma variable.
        """
        return self._gamma

    @property
    def func(self):
        """
        The callable activation function.
        """
        return self._func
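Any new adaptive activation can follow the same pattern as the concrete classes in ``adaptive_func.py``: subclass the interface and assign the wrapped callable to ``self._func``. A hypothetical sketch (``AdaptiveSoftplus`` is not part of this commit):

import torch
from pina.adaptive_function import AdaptiveActivationFunctionInterface


class AdaptiveSoftplus(AdaptiveActivationFunctionInterface):
    """Hypothetical adaptive Softplus: alpha * Softplus(beta * x + gamma)."""

    def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
        super().__init__(alpha, beta, gamma, fixed)
        self._func = torch.nn.Softplus()


act = AdaptiveSoftplus(fixed=["gamma"])
print(act(torch.randn(4)))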