Adaptive Functions (#272)
* adaptive function improvement

Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
@@ -72,6 +72,7 @@ Layers
    Continuous convolution <layers/convolution.rst>
    Proper Orthogonal Decomposition <layers/pod.rst>
    Periodic Boundary Condition embeddings <layers/embedding.rst>
    Adaptive Activation Function <layers/adaptive_func.rst>

Equations and Operators
-------------------------
docs/source/_rst/layers/adaptive_func.rst (new file, 7 lines)
@@ -0,0 +1,7 @@
AdaptiveActivationFunction
=============================

.. currentmodule:: pina.model.layers.adaptive_func

.. autoclass:: AdaptiveActivationFunction
   :members:
   :show-inheritance:
@@ -1,5 +0,0 @@
from .adaptive_tanh import AdaptiveTanh
from .adaptive_sin import AdaptiveSin
from .adaptive_cos import AdaptiveCos
from .adaptive_linear import AdaptiveLinear
from .adaptive_square import AdaptiveSquare
@@ -1,56 +0,0 @@
import torch
from torch.nn.parameter import Parameter


class AdaptiveCos(torch.nn.Module):
    """
    Implementation of soft exponential activation.
    Shape:
        - Input: (N, *) where * means, any number of additional
          dimensions
        - Output: (N, *), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - See related paper:
          https://arxiv.org/pdf/1602.01321.pdf
    Examples:
        >>> a1 = soft_exponential(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    """

    def __init__(self, alpha=None):
        """
        Initialization.
        INPUT:
            - in_features: shape of the input
            - aplha: trainable parameter
            aplha is initialized with zero value by default
        """
        super(AdaptiveCos, self).__init__()
        # self.in_features = in_features

        # initialize alpha
        if alpha == None:
            self.alpha = Parameter(
                torch.tensor(1.0)
            )  # create a tensor out of alpha
        else:
            self.alpha = Parameter(
                torch.tensor(alpha)
            )  # create a tensor out of alpha
        self.alpha.requiresGrad = True  # set requiresGrad to true!

        self.scale = Parameter(torch.tensor(1.0))
        self.scale.requiresGrad = True  # set requiresGrad to true!

        self.translate = Parameter(torch.tensor(0.0))
        self.translate.requiresGrad = True  # set requiresGrad to true!

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        """
        return self.scale * (torch.cos(self.alpha * x + self.translate))
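Each of the per-function modules removed here is superseded by the unified AdaptiveActivationFunction introduced later in this commit. As a rough sketch (assuming PINA from this branch), the removed AdaptiveCos, i.e. scale * cos(alpha * x + translate), maps onto the new parameterisation alpha * f(beta * x + gamma) by wrapping torch.cos directly:

    import torch
    from pina.model.layers import AdaptiveActivationFunction

    # alpha * cos(beta * x + gamma), all three parameters trainable
    adaptive_cos = AdaptiveActivationFunction(torch.cos)
    out = adaptive_cos(torch.rand(5))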
@@ -1,53 +0,0 @@
import torch
from torch.nn.parameter import Parameter


class AdaptiveExp(torch.nn.Module):
    """
    Implementation of soft exponential activation.
    Shape:
        - Input: (N, *) where * means, any number of additional
          dimensions
        - Output: (N, *), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - See related paper:
          https://arxiv.org/pdf/1602.01321.pdf
    Examples:
        >>> a1 = soft_exponential(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    """

    def __init__(self):
        """
        Initialization.
        INPUT:
            - in_features: shape of the input
            - aplha: trainable parameter
            aplha is initialized with zero value by default
        """
        super(AdaptiveExp, self).__init__()

        self.scale = Parameter(
            torch.normal(torch.tensor(1.0), torch.tensor(0.1))
        )  # create a tensor out of alpha
        self.scale.requiresGrad = True  # set requiresGrad to true!

        self.alpha = Parameter(
            torch.normal(torch.tensor(1.0), torch.tensor(0.1))
        )  # create a tensor out of alpha
        self.alpha.requiresGrad = True  # set requiresGrad to true!

        self.translate = Parameter(
            torch.normal(torch.tensor(0.0), torch.tensor(0.1))
        )  # create a tensor out of alpha
        self.translate.requiresGrad = True  # set requiresGrad to true!

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        """
        return self.scale * (x + self.translate)
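Note that the removed AdaptiveExp.forward never applied an exponential: it returned scale * (x + translate) and left alpha unused. With the unified class, an exponential variant can be sketched by wrapping torch.exp (an assumption about intent, not part of the commit):

    import torch
    from pina.model.layers import AdaptiveActivationFunction

    # alpha * exp(beta * x + gamma)
    adaptive_exp = AdaptiveActivationFunction(torch.exp)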
@@ -1,46 +0,0 @@
""" Implementation of adaptive linear layer. """

import torch
from torch.nn.parameter import Parameter


class AdaptiveLinear(torch.nn.Module):
    """
    Implementation of soft exponential activation.
    Shape:
        - Input: (N, *) where * means, any number of additional
          dimensions
        - Output: (N, *), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - See related paper:
          https://arxiv.org/pdf/1602.01321.pdf
    Examples:
        >>> a1 = soft_exponential(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    """

    def __init__(self):
        """
        Initialization.
        INPUT:
            - in_features: shape of the input
            - aplha: trainable parameter
            aplha is initialized with zero value by default
        """
        super(AdaptiveLinear, self).__init__()

        self.scale = Parameter(torch.tensor(1.0))
        self.scale.requiresGrad = True  # set requiresGrad to true!

        self.translate = Parameter(torch.tensor(0.0))
        self.translate.requiresGrad = True  # set requiresGrad to true!

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        """
        return self.scale * (x + self.translate)
@@ -1,45 +0,0 @@
import torch
from torch.nn.parameter import Parameter


class AdaptiveReLU(torch.nn.Module, Parameter):
    """
    Implementation of soft exponential activation.
    Shape:
        - Input: (N, *) where * means, any number of additional
          dimensions
        - Output: (N, *), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - See related paper:
          https://arxiv.org/pdf/1602.01321.pdf
    Examples:
        >>> a1 = soft_exponential(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    """

    def __init__(self):
        """
        Initialization.
        INPUT:
            - in_features: shape of the input
            - aplha: trainable parameter
            aplha is initialized with zero value by default
        """
        super(AdaptiveReLU, self).__init__()

        self.scale = Parameter(torch.rand(1))
        self.scale.requiresGrad = True  # set requiresGrad to true!

        self.translate = Parameter(torch.rand(1))
        self.translate.requiresGrad = True  # set requiresGrad to true!

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        """
        # x += self.translate
        return torch.relu(x + self.translate) * self.scale
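The removed AdaptiveReLU learned only a scale and a translation, i.e. scale * relu(x + translate). Under the new parameterisation this corresponds to alpha * relu(x + gamma), which can be sketched by freezing beta (again assuming PINA from this branch):

    import torch
    from pina.model.layers import AdaptiveActivationFunction

    # alpha * relu(beta * x + gamma) with beta fixed at 1
    adaptive_relu = AdaptiveActivationFunction(torch.nn.ReLU(), fixed=['beta'])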
@@ -1,54 +0,0 @@
import torch
from torch.nn.parameter import Parameter


class AdaptiveSin(torch.nn.Module):
    """
    Implementation of soft exponential activation.
    Shape:
        - Input: (N, *) where * means, any number of additional
          dimensions
        - Output: (N, *), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - See related paper:
          https://arxiv.org/pdf/1602.01321.pdf
    Examples:
        >>> a1 = soft_exponential(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    """

    def __init__(self, alpha=None):
        """
        Initialization.
        INPUT:
            - in_features: shape of the input
            - aplha: trainable parameter
            aplha is initialized with zero value by default
        """
        super(AdaptiveSin, self).__init__()

        # initialize alpha
        self.alpha = Parameter(
            torch.normal(torch.tensor(1.0), torch.tensor(0.1))
        )  # create a tensor out of alpha
        self.alpha.requiresGrad = True  # set requiresGrad to true!

        self.scale = Parameter(
            torch.normal(torch.tensor(1.0), torch.tensor(0.1))
        )
        self.scale.requiresGrad = True  # set requiresGrad to true!

        self.translate = Parameter(
            torch.normal(torch.tensor(0.0), torch.tensor(0.1))
        )
        self.translate.requiresGrad = True  # set requiresGrad to true!

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        """
        return self.scale * (torch.sin(self.alpha * x + self.translate))
@@ -1,44 +0,0 @@
import torch
from torch.nn.parameter import Parameter


class AdaptiveSoftplus(torch.nn.Module):
    """
    Implementation of soft exponential activation.
    Shape:
        - Input: (N, *) where * means, any number of additional
          dimensions
        - Output: (N, *), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - See related paper:
          https://arxiv.org/pdf/1602.01321.pdf
    Examples:
        >>> a1 = soft_exponential(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    """

    def __init__(self):
        """
        Initialization.
        INPUT:
            - in_features: shape of the input
            - aplha: trainable parameter
            aplha is initialized with zero value by default
        """
        super().__init__()

        self.soft = torch.nn.Softplus()

        self.scale = Parameter(torch.rand(1))
        self.scale.requiresGrad = True  # set requiresGrad to true!

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        """
        # x += self.translate
        return self.soft(x) * self.scale
@@ -1,44 +0,0 @@
import torch
from torch.nn.parameter import Parameter


class AdaptiveSquare(torch.nn.Module):
    """
    Implementation of soft exponential activation.
    Shape:
        - Input: (N, *) where * means, any number of additional
          dimensions
        - Output: (N, *), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - See related paper:
          https://arxiv.org/pdf/1602.01321.pdf
    Examples:
        >>> a1 = soft_exponential(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    """

    def __init__(self, alpha=None):
        """
        Initialization.
        INPUT:
            - in_features: shape of the input
            - aplha: trainable parameter
            aplha is initialized with zero value by default
        """
        super(AdaptiveSquare, self).__init__()

        self.scale = Parameter(torch.tensor(1.0))
        self.scale.requiresGrad = True  # set requiresGrad to true!

        self.translate = Parameter(torch.tensor(0.0))
        self.translate.requiresGrad = True  # set requiresGrad to true!

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        """
        return self.scale * (x + self.translate) ** 2
@@ -1,62 +0,0 @@
import torch
from torch.nn.parameter import Parameter


class AdaptiveTanh(torch.nn.Module):
    """
    Implementation of soft exponential activation.
    Shape:
        - Input: (N, *) where * means, any number of additional
          dimensions
        - Output: (N, *), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - See related paper:
          https://arxiv.org/pdf/1602.01321.pdf
    Examples:
        >>> a1 = soft_exponential(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    """

    def __init__(self, alpha=None):
        """
        Initialization.
        INPUT:
            - in_features: shape of the input
            - aplha: trainable parameter
            aplha is initialized with zero value by default
        """
        super(AdaptiveTanh, self).__init__()
        # self.in_features = in_features

        # initialize alpha
        if alpha == None:
            self.alpha = Parameter(
                torch.tensor(1.0)
            )  # create a tensor out of alpha
        else:
            self.alpha = Parameter(
                torch.tensor(alpha)
            )  # create a tensor out of alpha

        self.alpha.requiresGrad = True  # set requiresGrad to true!

        self.scale = Parameter(torch.tensor(1.0))
        self.scale.requiresGrad = True  # set requiresGrad to true!

        self.translate = Parameter(torch.tensor(0.0))
        self.translate.requiresGrad = True  # set requiresGrad to true!

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        """
        x += self.translate
        return (
            self.scale
            * (torch.exp(self.alpha * x) - torch.exp(-self.alpha * x))
            / (torch.exp(self.alpha * x) + torch.exp(-self.alpha * x))
        )
@@ -11,6 +11,7 @@ __all__ = [
    "PODBlock",
    "PeriodicBoundaryEmbedding",
    "AVNOBlock",
    "AdaptiveActivationFunction"
]

from .convolution_2d import ContinuousConvBlock
@@ -24,3 +25,4 @@ from .fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D
from .pod import PODBlock
from .embedding import PeriodicBoundaryEmbedding
from .avno_layer import AVNOBlock
from .adaptive_func import AdaptiveActivationFunction
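With the export above in place, the new class should be importable directly from the layers subpackage. A minimal check (assuming a development install of this branch):

    import torch
    from pina.model.layers import AdaptiveActivationFunction

    act = AdaptiveActivationFunction(torch.nn.Sigmoid())
    print(act(torch.rand(4)))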
pina/model/layers/adaptive_func.py (new file, 151 lines)
@@ -0,0 +1,151 @@
""" Module for adaptive functions. """

import torch
from pina.utils import check_consistency


class AdaptiveActivationFunction(torch.nn.Module):
    r"""
    The :class:`~pina.model.layers.adaptive_func.AdaptiveActivationFunction`
    class turns a :class:`torch.nn.Module` activation function into an
    adaptive, trainable activation function.

    Given a function :math:`f:\mathbb{R}^n\rightarrow\mathbb{R}^m`, the
    adaptive function
    :math:`f_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^m`
    is defined as:

    .. math::
        f_{\text{adaptive}}(\mathbf{x}) = \alpha\,f(\beta\mathbf{x}+\gamma),

    where :math:`\alpha,\,\beta,\,\gamma` are trainable parameters.

    :Example:
        >>> import torch
        >>> from pina.model.layers import AdaptiveActivationFunction
        >>>
        >>> # simple adaptive function with all trainable parameters
        >>> AdaptiveTanh = AdaptiveActivationFunction(torch.nn.Tanh())
        >>> AdaptiveTanh(torch.rand(3))
        tensor([0.1084, 0.3931, 0.7294], grad_fn=<MulBackward0>)
        >>> AdaptiveTanh.alpha
        Parameter containing:
        tensor(1., requires_grad=True)
        >>>
        >>> # adaptive function with alpha fixed (not trainable)
        >>> AdaptiveTanh = AdaptiveActivationFunction(torch.nn.Tanh(),
        ...                                           fixed=['alpha'])
        >>> AdaptiveTanh.alpha
        tensor(1.)
        >>> AdaptiveTanh.beta
        Parameter containing:
        tensor(1., requires_grad=True)

    .. seealso::

        **Original reference**: Godfrey, Luke B., and Michael S. Gashler.
        *A continuum among logarithmic, linear, and exponential functions,
        and its potential to improve generalization in neural networks.*
        2015 7th International Joint Conference on Knowledge Discovery,
        Knowledge Engineering and Knowledge Management (IC3K),
        Vol. 1, IEEE, 2015. arXiv preprint
        `arXiv:1602.01321 <https://arxiv.org/abs/1602.01321>`_.
    """

    def __init__(self, func, alpha=None, beta=None, gamma=None, fixed=None):
        """
        Initializes the AdaptiveActivationFunction module.

        :param callable func: The original callable function. It can be an
            initialized :class:`torch.nn.Module` or a Python callable.
        :param float | complex alpha: Scaling parameter alpha.
            Defaults to ``None``; when ``None`` is passed,
            the variable is initialized to 1.
        :param float | complex beta: Scaling parameter beta.
            Defaults to ``None``; when ``None`` is passed,
            the variable is initialized to 1.
        :param float | complex gamma: Shifting parameter gamma.
            Defaults to ``None``; when ``None`` is passed,
            the variable is initialized to 0.
        :param list fixed: List of parameters to fix during training,
            i.e. not optimized (``requires_grad`` set to ``False``).
            Options are ``['alpha', 'beta', 'gamma']``. Defaults to ``None``.
        """
        super().__init__()

        # see if there are fixed variables
        if fixed is not None:
            check_consistency(fixed, str)
            if not all(key in ['alpha', 'beta', 'gamma'] for key in fixed):
                raise TypeError("Fixed keys must be in "
                                "['alpha', 'beta', 'gamma'].")

        # initialize alpha, beta, gamma if they are None
        if alpha is None:
            alpha = 1.
        if beta is None:
            beta = 1.
        if gamma is None:
            gamma = 0.

        # checking consistency
        check_consistency(alpha, (float, complex))
        check_consistency(beta, (float, complex))
        check_consistency(gamma, (float, complex))
        if not callable(func):
            raise ValueError("Function must be a callable function.")

        # registering as tensors
        alpha = torch.tensor(alpha, requires_grad=False)
        beta = torch.tensor(beta, requires_grad=False)
        gamma = torch.tensor(gamma, requires_grad=False)

        # set the non-fixed variables as torch.nn.Parameter with gradient and
        # register a buffer for the fixed ones (buffers are saved alongside
        # trainable parameters by default). Buffers are stored under the
        # underscored names so that the read-only properties below resolve.
        if 'alpha' not in (fixed or []):
            self._alpha = torch.nn.Parameter(alpha, requires_grad=True)
        else:
            self.register_buffer('_alpha', alpha)

        if 'beta' not in (fixed or []):
            self._beta = torch.nn.Parameter(beta, requires_grad=True)
        else:
            self.register_buffer('_beta', beta)

        if 'gamma' not in (fixed or []):
            self._gamma = torch.nn.Parameter(gamma, requires_grad=True)
        else:
            self.register_buffer('_gamma', gamma)

        # registering function
        self._func = func

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        """
        return self.alpha * (self._func(self.beta * x + self.gamma))

    @property
    def alpha(self):
        """
        The alpha variable.
        """
        return self._alpha

    @property
    def beta(self):
        """
        The beta variable.
        """
        return self._beta

    @property
    def gamma(self):
        """
        The gamma variable.
        """
        return self._gamma
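As a quick illustration of how the new layer composes with standard PyTorch modules, a minimal sketch (the small network below is illustrative only, not part of the commit):

    import torch
    from pina.model.layers import AdaptiveActivationFunction

    # a tiny MLP that uses an adaptive tanh as its hidden activation
    model = torch.nn.Sequential(
        torch.nn.Linear(2, 16),
        AdaptiveActivationFunction(torch.nn.Tanh()),  # alpha, beta, gamma trainable
        torch.nn.Linear(16, 1),
    )

    x = torch.rand(8, 2)
    loss = model(x).pow(2).mean()
    loss.backward()  # gradients also reach alpha, beta and gamma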
tests/test_layers/test_adaptive_func.py (new file, 48 lines)
@@ -0,0 +1,48 @@
import torch
import pytest

from pina.model.layers.adaptive_func import AdaptiveActivationFunction

x = torch.rand(5)
torchfunc = torch.nn.Tanh()


def test_constructor():
    # simple
    AdaptiveActivationFunction(torchfunc)

    # setting values
    af = AdaptiveActivationFunction(torchfunc, alpha=1., beta=2., gamma=3.)
    assert af.alpha.requires_grad
    assert af.beta.requires_grad
    assert af.gamma.requires_grad
    assert af.alpha == 1.
    assert af.beta == 2.
    assert af.gamma == 3.

    # fixed variables
    af = AdaptiveActivationFunction(torchfunc, alpha=1., beta=2.,
                                    gamma=3., fixed=['alpha'])
    assert af.alpha.requires_grad is False
    assert af.beta.requires_grad
    assert af.gamma.requires_grad
    assert af.alpha == 1.
    assert af.beta == 2.
    assert af.gamma == 3.

    with pytest.raises(TypeError):
        AdaptiveActivationFunction(torchfunc, alpha=1., beta=2.,
                                   gamma=3., fixed=['delta'])

    with pytest.raises(ValueError):
        AdaptiveActivationFunction(torchfunc, alpha='s')
        AdaptiveActivationFunction(torchfunc, alpha=1., fixed='alpha')
        AdaptiveActivationFunction(torchfunc, alpha=1)


def test_forward():
    af = AdaptiveActivationFunction(torchfunc)
    af(x)


def test_backward():
    af = AdaptiveActivationFunction(torchfunc)
    y = af(x)
    y.mean().backward()
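The new tests can presumably be run in isolation from the repository root with pytest tests/test_layers/test_adaptive_func.py (assuming a development install of this branch).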