Backpropagation and fix test for OrthogonalBlock

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
    Co-authored-by: Gabriele Codega <gcodega@pascal.maths.sissa.it>
This commit is contained in:
Dario Coscia
2024-09-03 16:23:14 +02:00
committed by Nicola Demo
parent 59fc19798f
commit eea0cc0833
3 changed files with 125 additions and 25 deletions

View File

@@ -1,6 +1,8 @@
import torch
import pytest
from pina.model.layers.orthogonal import OrthogonalBlock
from pina.model.layers import OrthogonalBlock
torch.manual_seed(111)
list_matrices = [
torch.randn(10, 3),
@@ -10,10 +12,28 @@ list_matrices = [
list_prohibited_matrices_dim0 = list_matrices[:-1]
def test_constructor():
orth = OrthogonalBlock(1)
orth = OrthogonalBlock(0)
orth = OrthogonalBlock()
@pytest.mark.parametrize("dim", [-1, 0, 1, None])
@pytest.mark.parametrize("requires_grad", [True, False, None])
def test_constructor(dim, requires_grad):
if dim is None and requires_grad is None:
block = OrthogonalBlock()
elif dim is None:
block = OrthogonalBlock(requires_grad=requires_grad)
elif requires_grad is None:
block = OrthogonalBlock(dim=dim)
else:
block = OrthogonalBlock(dim=dim, requires_grad=requires_grad)
if dim is not None:
assert block.dim == dim
if requires_grad is not None:
assert block.requires_grad == requires_grad
def test_wrong_constructor():
with pytest.raises(IndexError):
OrthogonalBlock(2)
with pytest.raises(ValueError):
OrthogonalBlock('a')
@pytest.mark.parametrize("V", list_matrices)
def test_forward(V):
@@ -24,6 +44,21 @@ def test_forward(V):
assert torch.allclose(V_orth.T @ V_orth, torch.eye(V.shape[1]), atol=1e-6)
assert torch.allclose(V_orth_row @ V_orth_row.T, torch.eye(V.shape[1]), atol=1e-6)
@pytest.mark.parametrize("V", list_matrices)
def test_backward(V):
orth = OrthogonalBlock(requires_grad=True)
V_orth = orth(V)
loss = V_orth.mean()
loss.backward()
@pytest.mark.parametrize("V", list_matrices)
def test_wrong_backward(V):
orth = OrthogonalBlock(requires_grad=False)
V_orth = orth(V)
loss = V_orth.mean()
with pytest.raises(RuntimeError):
loss.backward()
@pytest.mark.parametrize("V", list_prohibited_matrices_dim0)
def test_forward_prohibited(V):
orth = OrthogonalBlock(0)