Add backpropagation support and fix tests for OrthogonalBlock

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
Co-authored-by: Gabriele Codega <gcodega@pascal.maths.sissa.it>
Author:       Dario Coscia
Date:         2024-09-03 16:23:14 +02:00
Committed by: Nicola Demo
parent 59fc19798f
commit eea0cc0833

3 changed files with 125 additions and 25 deletions

pina/model/layers/__init__.py

@@ -9,6 +9,7 @@ __all__ = [
     "FourierBlock2D",
     "FourierBlock3D",
     "PODBlock",
+    "OrthogonalBlock",
     "PeriodicBoundaryEmbedding",
     "FourierFeatureEmbedding",
     "AVNOBlock",
@@ -25,6 +26,7 @@ from .spectral import (
 )
 from .fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D
 from .pod import PODBlock
+from .orthogonal import OrthogonalBlock
 from .embedding import PeriodicBoundaryEmbedding, FourierFeatureEmbedding
 from .avno_layer import AVNOBlock
 from .lowrank_layer import LowRankBlock
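
With the two lines added above, OrthogonalBlock becomes part of the package's public interface, so downstream code (and the updated test below) can import it directly from pina.model.layers:

    from pina.model.layers import OrthogonalBlock  # new package-level import

    # the fully qualified module path used by the old test still works:
    # from pina.model.layers.orthogonal import OrthogonalBlock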

pina/model/layers/orthogonal.py

@@ -1,23 +1,33 @@
-"""Module for OrthogonalBlock layer, to make the input orthonormal."""
+"""Module for OrthogonalBlock."""
 import torch
+from ...utils import check_consistency


 class OrthogonalBlock(torch.nn.Module):
     """
     Module to make the input orthonormal.
-    The module takes a tensor of size [N, M] and returns a tensor of
-    size [N, M] where the columns are orthonormal.
+    The module takes a tensor of size :math:`[N, M]` and returns a tensor of
+    size :math:`[N, M]` where the columns are orthonormal. The block performs
+    a Gram-Schmidt orthogonalization process on the input, see
+    `here <https://en.wikipedia.org/wiki/Gram%E2%80%93Schmidt_process>`_ for
+    details.
     """

-    def __init__(self, dim=-1):
+    def __init__(self, dim=-1, requires_grad=True):
         """
         Initialize the OrthogonalBlock module.

         :param int dim: The dimension along which to orthogonalize.
+        :param bool requires_grad: If autograd should record operations on
+            the returned tensor, defaults to True.
         """
         super().__init__()
+        # store dim
         self.dim = dim
+        # store requires_grad
+        check_consistency(requires_grad, bool)
+        self._requires_grad = requires_grad

     def forward(self, X):
         """
@@ -26,7 +36,8 @@ class OrthogonalBlock(torch.nn.Module):
         :raises Warning: If the dimension is greater than the other dimensions.

-        :param torch.Tensor X: The input tensor to orthogonalize.
+        :param torch.Tensor X: The input tensor to orthogonalize. The input
+            must be of dimensions :math:`[N, M]`.
         :return: The orthonormal tensor.
         """
         # check dim is less than all the other dimensions
@@ -36,23 +47,75 @@ class OrthogonalBlock(torch.nn.Module):
                 " than the other dimensions"
             )

-        result = torch.zeros_like(X)
-        # normalize first basis
-        X_0 = torch.select(X, self.dim, 0)
-        result_0 = torch.select(result, self.dim, 0)
-        result_0 += X_0 / torch.norm(X_0)
+        result = torch.zeros_like(X, requires_grad=self._requires_grad)
+        # normalize first basis
+        X_0 = torch.select(X, self.dim, 0).clone()
+        result_0 = X_0 / torch.linalg.norm(X_0)
+        result = self._differentiable_copy(result, 0, result_0)

         # iterate over the rest of the basis with Gram-Schmidt
         for i in range(1, X.shape[self.dim]):
-            v = torch.select(X, self.dim, i)
+            v = torch.select(X, self.dim, i).clone()
             for j in range(i):
-                v -= torch.sum(
-                    v * torch.select(result, self.dim, j),
-                    dim=self.dim,
-                    keepdim=True,
-                ) * torch.select(result, self.dim, j)
-            result_i = torch.select(result, self.dim, i)
-            result_i += v / torch.norm(v)
+                vj = torch.select(result, self.dim, j).clone()
+                v = v - torch.sum(v * vj, dim=self.dim, keepdim=True) * vj
+            result_i = v / torch.linalg.norm(v)
+            result = self._differentiable_copy(result, i, result_i)
         return result

+    def _differentiable_copy(self, result, idx, value):
+        """
+        Perform a differentiable copy operation on a tensor.
+
+        :param torch.Tensor result: The tensor where values will be copied to.
+        :param int idx: The index along the specified dimension where the
+            value will be copied.
+        :param torch.Tensor value: The tensor value to copy into the
+            result tensor.
+        :return: A new tensor with the copied values.
+        :rtype: torch.Tensor
+        """
+        return result.index_copy(
+            self.dim, torch.tensor([idx]), value.unsqueeze(self.dim)
+        )
+
+    @property
+    def dim(self):
+        """
+        Get the dimension along which operations are performed.
+
+        :return: The current dimension value.
+        :rtype: int
+        """
+        return self._dim
+
+    @dim.setter
+    def dim(self, value):
+        """
+        Set the dimension along which operations are performed.
+
+        :param value: The dimension to be set, which must be 0, 1, or -1.
+        :type value: int
+        :raises IndexError: If the provided dimension is not in the
+            range [-1, 1].
+        """
+        # check consistency
+        check_consistency(value, int)
+        if value not in [0, 1, -1]:
+            raise IndexError('Dimension out of range (expected to be in '
+                             f'range of [-1, 1], but got {value})')
+        # assign value
+        self._dim = value
+
+    @property
+    def requires_grad(self):
+        """
+        Indicate whether gradient computation is required for operations
+        on the returned tensors.
+
+        :return: True if gradients are required, False otherwise.
+        :rtype: bool
+        """
+        return self._requires_grad
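
The heart of the fix is in the hunk above. The old forward() accumulated into views of result with in-place += on torch.select, which autograd either rejects outright or later invalidates through its version counters, so loss.backward() could not traverse the Gram-Schmidt loop. The new _differentiable_copy uses the out-of-place index_copy, which returns a fresh tensor and leaves the graph intact. A minimal sketch of the difference in plain PyTorch (tensor names here are illustrative, not from the codebase):

    import torch

    # a leaf tensor that records gradients, like the new
    # torch.zeros_like(X, requires_grad=True) in forward()
    result = torch.zeros(3, 2, requires_grad=True)
    v = torch.ones(2)

    # old approach: in-place write into a view of the leaf;
    # PyTorch refuses it with "a leaf Variable that requires grad
    # is being used in an in-place operation"
    try:
        result.select(0, 0).add_(v)
    except RuntimeError as err:
        print("in-place write fails:", err)

    # new approach: out-of-place index_copy returns a new tensor,
    # so the autograd graph stays valid and backward() succeeds
    out = result.index_copy(0, torch.tensor([0]), v.unsqueeze(0))
    out.sum().backward()
    print(result.grad)  # rows that were not overwritten receive gradient 1

The .clone() calls added to the torch.select reads play the matching role on the read side: they give v and vj their own storage, so the loop never aliases or mutates X and result through views.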

tests/test_layers/test_orthogonal.py

@@ -1,6 +1,8 @@
 import torch
 import pytest
-from pina.model.layers.orthogonal import OrthogonalBlock
+from pina.model.layers import OrthogonalBlock
+
+torch.manual_seed(111)

 list_matrices = [
     torch.randn(10, 3),
@@ -10,10 +12,28 @@ list_matrices = [
 list_prohibited_matrices_dim0 = list_matrices[:-1]


-def test_constructor():
-    orth = OrthogonalBlock(1)
-    orth = OrthogonalBlock(0)
-    orth = OrthogonalBlock()
+@pytest.mark.parametrize("dim", [-1, 0, 1, None])
+@pytest.mark.parametrize("requires_grad", [True, False, None])
+def test_constructor(dim, requires_grad):
+    if dim is None and requires_grad is None:
+        block = OrthogonalBlock()
+    elif dim is None:
+        block = OrthogonalBlock(requires_grad=requires_grad)
+    elif requires_grad is None:
+        block = OrthogonalBlock(dim=dim)
+    else:
+        block = OrthogonalBlock(dim=dim, requires_grad=requires_grad)
+    if dim is not None:
+        assert block.dim == dim
+    if requires_grad is not None:
+        assert block.requires_grad == requires_grad
+
+
+def test_wrong_constructor():
+    with pytest.raises(IndexError):
+        OrthogonalBlock(2)
+    with pytest.raises(ValueError):
+        OrthogonalBlock('a')


 @pytest.mark.parametrize("V", list_matrices)
 def test_forward(V):
@@ -24,6 +44,21 @@ def test_forward(V):
     assert torch.allclose(V_orth.T @ V_orth, torch.eye(V.shape[1]), atol=1e-6)
     assert torch.allclose(V_orth_row @ V_orth_row.T, torch.eye(V.shape[1]), atol=1e-6)


+@pytest.mark.parametrize("V", list_matrices)
+def test_backward(V):
+    orth = OrthogonalBlock(requires_grad=True)
+    V_orth = orth(V)
+    loss = V_orth.mean()
+    loss.backward()
+
+
+@pytest.mark.parametrize("V", list_matrices)
+def test_wrong_backward(V):
+    orth = OrthogonalBlock(requires_grad=False)
+    V_orth = orth(V)
+    loss = V_orth.mean()
+    with pytest.raises(RuntimeError):
+        loss.backward()
+
+
 @pytest.mark.parametrize("V", list_prohibited_matrices_dim0)
 def test_forward_prohibited(V):
     orth = OrthogonalBlock(0)
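
For reference, an end-to-end sketch of the new behavior, mirroring the tests above (the shape and seed are illustrative):

    import torch
    from pina.model.layers import OrthogonalBlock

    torch.manual_seed(111)
    V = torch.randn(10, 3)

    orth = OrthogonalBlock()          # defaults: dim=-1, requires_grad=True
    Q = orth(V)

    # the columns of Q are orthonormal
    assert torch.allclose(Q.T @ Q, torch.eye(3), atol=1e-6)

    # gradients now flow through the Gram-Schmidt loop
    Q.mean().backward()

    # constructor arguments are validated: only dim in {0, 1, -1} is accepted
    try:
        OrthogonalBlock(dim=2)
    except IndexError as err:
        print(err)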