From 62d50e2455b510f668dcedca32f9b30f25412d94 Mon Sep 17 00:00:00 2001 From: Anna Ivagnes Date: Fri, 16 Aug 2024 16:48:07 +0200 Subject: [PATCH] add OrthogonalBlock to make input orthonormal --- docs/source/_rst/layers/orthogonal.rst | 7 ++++ pina/model/layers/orthogonal.py | 50 ++++++++++++++++++++++++++ tests/test_layers/test_orthogonal.py | 33 +++++++++++++++++ 3 files changed, 90 insertions(+) create mode 100644 docs/source/_rst/layers/orthogonal.rst create mode 100644 pina/model/layers/orthogonal.py create mode 100644 tests/test_layers/test_orthogonal.py diff --git a/docs/source/_rst/layers/orthogonal.rst b/docs/source/_rst/layers/orthogonal.rst new file mode 100644 index 0000000..6dfc400 --- /dev/null +++ b/docs/source/_rst/layers/orthogonal.rst @@ -0,0 +1,7 @@ +OrthogonalBlock +====================== +.. currentmodule:: pina.model.layers.orthogonal + +.. autoclass:: OrthogonalBlock + :members: + :show-inheritance: \ No newline at end of file diff --git a/pina/model/layers/orthogonal.py b/pina/model/layers/orthogonal.py new file mode 100644 index 0000000..6296502 --- /dev/null +++ b/pina/model/layers/orthogonal.py @@ -0,0 +1,50 @@ +"""Module for OrthogonalBlock layer, to make the input orthonormal.""" + +import torch + + +class OrthogonalBlock(torch.nn.Module): + """ + Module to make the input orthonormal. + The module takes a tensor of size [N, M] and returns a tensor of + size [N, M] where the columns are orthonormal. + """ + def __init__(self, dim=-1): + """ + Initialize the OrthogonalBlock module. + + :param int dim: The dimension where to orthogonalize. + """ + super().__init__() + self.dim = dim + + def forward(self, X): + """ + Forward pass of the OrthogonalBlock module using a Gram-Schmidt + algorithm. + + :raises Warning: If the dimension is greater than the other dimensions. + + :param torch.Tensor X: The input tensor to orthogonalize. + :return: The orthonormal tensor. + """ + # check dim is less than all the other dimensions + if X.shape[self.dim] > min(X.shape): + raise Warning("The dimension where to orthogonalize is greater\ + than the other dimensions") + + result = torch.zeros_like(X) + # normalize first basis + X_0 = torch.select(X, self.dim, 0) + result_0 = torch.select(result, self.dim, 0) + result_0 += X_0/torch.norm(X_0) + # iterate over the rest of the basis with Gram-Schmidt + for i in range(1, X.shape[self.dim]): + v = torch.select(X, self.dim, i) + for j in range(i): + v -= torch.sum(v * torch.select(result, self.dim, j), + dim=self.dim, keepdim=True) * torch.select( + result, self.dim, j) + result_i = torch.select(result, self.dim, i) + result_i += v/torch.norm(v) + return result \ No newline at end of file diff --git a/tests/test_layers/test_orthogonal.py b/tests/test_layers/test_orthogonal.py new file mode 100644 index 0000000..30b59cd --- /dev/null +++ b/tests/test_layers/test_orthogonal.py @@ -0,0 +1,33 @@ +import torch +import pytest +from pina.model.layers.orthogonal import OrthogonalBlock + +list_matrices = [ + torch.randn(10, 3), + torch.rand(100, 5), + torch.randn(5, 5), + ] + +list_prohibited_matrices_dim0 = list_matrices[:-1] + +def test_constructor(): + orth = OrthogonalBlock(1) + orth = OrthogonalBlock(0) + orth = OrthogonalBlock() + +@pytest.mark.parametrize("V", list_matrices) +def test_forward(V): + orth = OrthogonalBlock() + orth_row = OrthogonalBlock(0) + V_orth = orth(V) + V_orth_row = orth_row(V.T) + assert torch.allclose(V_orth.T @ V_orth, torch.eye(V.shape[1]), atol=1e-6) + assert torch.allclose(V_orth_row @ V_orth_row.T, torch.eye(V.shape[1]), atol=1e-6) + +@pytest.mark.parametrize("V", list_prohibited_matrices_dim0) +def test_forward_prohibited(V): + orth = OrthogonalBlock(0) + with pytest.raises(Warning): + V_orth = orth(V) + assert V.shape[0] > V.shape[1] +