From 3c1fed9ae234e94e5976bed3ee86ed93d5609438 Mon Sep 17 00:00:00 2001 From: Anna Ivagnes Date: Mon, 28 Apr 2025 17:03:23 +0200 Subject: [PATCH] add singular values in PODBlock --- pina/model/block/pod_block.py | 24 +++++++++++++++++++++--- tests/test_blocks/test_pod.py | 4 ++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/pina/model/block/pod_block.py b/pina/model/block/pod_block.py index 290cb0d..cf38f19 100644 --- a/pina/model/block/pod_block.py +++ b/pina/model/block/pod_block.py @@ -30,6 +30,7 @@ class PODBlock(torch.nn.Module): super().__init__() self.__scale_coefficients = scale_coefficients self._basis = None + self._singular_values = None self._scaler = None self._rank = rank @@ -70,6 +71,19 @@ class PODBlock(torch.nn.Module): return self._basis[: self.rank] + @property + def singular_values(self): + """ + The singular values of the POD basis. + + :return: The singular values. + :rtype: torch.Tensor + """ + if self._singular_values is None: + return None + + return self._singular_values[: self.rank] + @property def scaler(self): """ @@ -136,15 +150,19 @@ class PODBlock(torch.nn.Module): "This may slow down computations.", ResourceWarning, ) - self._basis = torch.svd(X.T)[0].T + u, s, v = torch.svd(X.T) else: if randomized: warnings.warn( "Considering a randomized algorithm to compute the POD basis" ) - self._basis = torch.svd_lowrank(X.T, q=X.shape[0])[0].T + u, s, v = torch.svd_lowrank(X.T, q=X.shape[0]) + else: - self._basis = torch.svd(X.T)[0].T + u, s, v = torch.svd(X.T) + self._basis = u.T + self._singular_values = s + def forward(self, X): """ diff --git a/tests/test_blocks/test_pod.py b/tests/test_blocks/test_pod.py index a0823bc..8cee923 100644 --- a/tests/test_blocks/test_pod.py +++ b/tests/test_blocks/test_pod.py @@ -23,6 +23,8 @@ def test_fit(rank, scale): assert pod._basis == None assert pod.basis == None assert pod._scaler == None + assert pod._singular_values == None + assert pod.singular_values == None assert pod.rank == rank assert pod.scale_coefficients == scale @@ -37,6 +39,8 @@ def test_fit(rank, scale, randomized): dof = toy_snapshots.shape[1] assert pod.basis.shape == (rank, dof) assert pod._basis.shape == (n_snap, dof) + assert pod.singular_values.shape == (rank,) + assert pod._singular_values.shape == (n_snap,) if scale is True: assert pod._scaler["mean"].shape == (n_snap,) assert pod._scaler["std"].shape == (n_snap,)