Files
PINA/tests/test_adaptive_function.py
Filippo Olivo 4177bfbb50 Fix Codacy Warnings (#477)
---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
2025-03-19 17:48:18 +01:00

86 lines
1.8 KiB
Python

import torch
import pytest
from pina.adaptive_function import (
AdaptiveReLU,
AdaptiveSigmoid,
AdaptiveTanh,
AdaptiveSiLU,
AdaptiveMish,
AdaptiveELU,
AdaptiveCELU,
AdaptiveGELU,
AdaptiveSoftmin,
AdaptiveSoftmax,
AdaptiveSIREN,
AdaptiveExp,
)
adaptive_function = (
AdaptiveReLU,
AdaptiveSigmoid,
AdaptiveTanh,
AdaptiveSiLU,
AdaptiveMish,
AdaptiveELU,
AdaptiveCELU,
AdaptiveGELU,
AdaptiveSoftmin,
AdaptiveSoftmax,
AdaptiveSIREN,
AdaptiveExp,
)
x = torch.rand(10, requires_grad=True)
@pytest.mark.parametrize("Func", adaptive_function)
def test_constructor(Func):
if Func.__name__ == "AdaptiveExp":
# simple
Func()
# setting values
af = Func(alpha=1.0, beta=2.0)
assert af.alpha.requires_grad
assert af.beta.requires_grad
assert af.alpha == 1.0
assert af.beta == 2.0
else:
# simple
Func()
# setting values
af = Func(alpha=1.0, beta=2.0, gamma=3.0)
assert af.alpha.requires_grad
assert af.beta.requires_grad
assert af.gamma.requires_grad
assert af.alpha == 1.0
assert af.beta == 2.0
assert af.gamma == 3.0
# fixed variables
af = Func(alpha=1.0, beta=2.0, fixed=["alpha"])
assert af.alpha.requires_grad is False
assert af.beta.requires_grad
assert af.alpha == 1.0
assert af.beta == 2.0
with pytest.raises(TypeError):
Func(alpha=1.0, beta=2.0, fixed=["delta"])
with pytest.raises(ValueError):
Func(alpha="s")
Func(alpha=1)
@pytest.mark.parametrize("Func", adaptive_function)
def test_forward(Func):
af = Func()
af(x)
@pytest.mark.parametrize("Func", adaptive_function)
def test_backward(Func):
af = Func()
y = af(x)
y.mean().backward()