add singular values in PODBlock

This commit is contained in:
Anna Ivagnes
2025-04-28 17:03:23 +02:00
committed by Dario Coscia
parent e3d4c2fc1a
commit 3c1fed9ae2
2 changed files with 25 additions and 3 deletions

View File

@@ -30,6 +30,7 @@ class PODBlock(torch.nn.Module):
super().__init__() super().__init__()
self.__scale_coefficients = scale_coefficients self.__scale_coefficients = scale_coefficients
self._basis = None self._basis = None
self._singular_values = None
self._scaler = None self._scaler = None
self._rank = rank self._rank = rank
@@ -70,6 +71,19 @@ class PODBlock(torch.nn.Module):
return self._basis[: self.rank] return self._basis[: self.rank]
@property
def singular_values(self):
"""
The singular values of the POD basis.
:return: The singular values.
:rtype: torch.Tensor
"""
if self._singular_values is None:
return None
return self._singular_values[: self.rank]
@property @property
def scaler(self): def scaler(self):
""" """
@@ -136,15 +150,19 @@ class PODBlock(torch.nn.Module):
"This may slow down computations.", "This may slow down computations.",
ResourceWarning, ResourceWarning,
) )
self._basis = torch.svd(X.T)[0].T u, s, v = torch.svd(X.T)
else: else:
if randomized: if randomized:
warnings.warn( warnings.warn(
"Considering a randomized algorithm to compute the POD basis" "Considering a randomized algorithm to compute the POD basis"
) )
self._basis = torch.svd_lowrank(X.T, q=X.shape[0])[0].T u, s, v = torch.svd_lowrank(X.T, q=X.shape[0])
else: else:
self._basis = torch.svd(X.T)[0].T u, s, v = torch.svd(X.T)
self._basis = u.T
self._singular_values = s
def forward(self, X): def forward(self, X):
""" """

View File

@@ -23,6 +23,8 @@ def test_fit(rank, scale):
assert pod._basis == None assert pod._basis == None
assert pod.basis == None assert pod.basis == None
assert pod._scaler == None assert pod._scaler == None
assert pod._singular_values == None
assert pod.singular_values == None
assert pod.rank == rank assert pod.rank == rank
assert pod.scale_coefficients == scale assert pod.scale_coefficients == scale
@@ -37,6 +39,8 @@ def test_fit(rank, scale, randomized):
dof = toy_snapshots.shape[1] dof = toy_snapshots.shape[1]
assert pod.basis.shape == (rank, dof) assert pod.basis.shape == (rank, dof)
assert pod._basis.shape == (n_snap, dof) assert pod._basis.shape == (n_snap, dof)
assert pod.singular_values.shape == (rank,)
assert pod._singular_values.shape == (n_snap,)
if scale is True: if scale is True:
assert pod._scaler["mean"].shape == (n_snap,) assert pod._scaler["mean"].shape == (n_snap,)
assert pod._scaler["std"].shape == (n_snap,) assert pod._scaler["std"].shape == (n_snap,)