Change SVD type in pod.py (#449)

* Change SVD type in pod.py
This commit is contained in:
Anna Ivagnes
2025-02-14 19:13:09 +01:00
committed by GitHub
parent d94256fac4
commit c28a4bd2b9
2 changed files with 22 additions and 11 deletions

View File

@@ -4,7 +4,7 @@ from abc import ABCMeta, abstractmethod
import torch import torch
from .stride import Stride from .stride import Stride
from .utils_convolution import optimizing from .utils_convolution import optimizing
import warnings
class PODBlock(torch.nn.Module): class PODBlock(torch.nn.Module):
""" """
@@ -85,7 +85,7 @@ class PODBlock(torch.nn.Module):
""" """
return self.__scale_coefficients return self.__scale_coefficients
def fit(self, X): def fit(self, X, randomized=True):
""" """
Set the POD basis by performing the singular value decomposition of the Set the POD basis by performing the singular value decomposition of the
given tensor. If `self.scale_coefficients` is True, the coefficients given tensor. If `self.scale_coefficients` is True, the coefficients
@@ -93,7 +93,7 @@ class PODBlock(torch.nn.Module):
:param torch.Tensor X: The tensor to be reduced. :param torch.Tensor X: The tensor to be reduced.
""" """
self._fit_pod(X) self._fit_pod(X, randomized)
if self.__scale_coefficients: if self.__scale_coefficients:
self._fit_scaler(torch.matmul(self._basis, X.T)) self._fit_scaler(torch.matmul(self._basis, X.T))
@@ -112,16 +112,24 @@ class PODBlock(torch.nn.Module):
"mean": torch.mean(coeffs, dim=1), "mean": torch.mean(coeffs, dim=1),
} }
def _fit_pod(self, X): def _fit_pod(self, X, randomized):
""" """
Private method that computes the POD basis of the given tensor and stores it in the private member `_basis`. Private method that computes the POD basis of the given tensor and stores it in the private member `_basis`.
:param torch.Tensor X: The tensor to be reduced. :param torch.Tensor X: The tensor to be reduced.
""" """
if X.device.type == "mps": # svd_lowrank not arailable for mps if X.device.type == "mps": # svd_lowrank not arailable for mps
warnings.warn(
"svd_lowrank not available for mps, using svd instead."
"This may slow down computations.", ResourceWarning
)
self._basis = torch.svd(X.T)[0].T self._basis = torch.svd(X.T)[0].T
else: else:
self._basis = torch.svd_lowrank(X.T, q=X.shape[0])[0].T if randomized:
warnings.warn("Considering a randomized algorithm to compute the POD basis")
self._basis = torch.svd_lowrank(X.T, q=X.shape[0])[0].T
else:
self._basis = torch.svd(X.T)[0].T
def forward(self, X): def forward(self, X):
""" """

View File

@@ -25,9 +25,10 @@ def test_fit(rank, scale):
@pytest.mark.parametrize("scale", [True, False]) @pytest.mark.parametrize("scale", [True, False])
@pytest.mark.parametrize("rank", [1, 2, 10]) @pytest.mark.parametrize("rank", [1, 2, 10])
def test_fit(rank, scale): @pytest.mark.parametrize("randomized", [True, False])
def test_fit(rank, scale, randomized):
pod = PODBlock(rank, scale) pod = PODBlock(rank, scale)
pod.fit(toy_snapshots) pod.fit(toy_snapshots, randomized)
n_snap = toy_snapshots.shape[0] n_snap = toy_snapshots.shape[0]
dof = toy_snapshots.shape[1] dof = toy_snapshots.shape[1]
assert pod.basis.shape == (rank, dof) assert pod.basis.shape == (rank, dof)
@@ -65,18 +66,20 @@ def test_forward():
@pytest.mark.parametrize("scale", [True, False]) @pytest.mark.parametrize("scale", [True, False])
@pytest.mark.parametrize("rank", [1, 2, 10]) @pytest.mark.parametrize("rank", [1, 2, 10])
def test_expand(rank, scale): @pytest.mark.parametrize("randomized", [True, False])
def test_expand(rank, scale, randomized):
pod = PODBlock(rank, scale) pod = PODBlock(rank, scale)
pod.fit(toy_snapshots) pod.fit(toy_snapshots, randomized)
c = pod(toy_snapshots) c = pod(toy_snapshots)
torch.testing.assert_close(pod.expand(c), toy_snapshots) torch.testing.assert_close(pod.expand(c), toy_snapshots)
torch.testing.assert_close(pod.expand(c[0]), toy_snapshots[0].unsqueeze(0)) torch.testing.assert_close(pod.expand(c[0]), toy_snapshots[0].unsqueeze(0))
@pytest.mark.parametrize("scale", [True, False]) @pytest.mark.parametrize("scale", [True, False])
@pytest.mark.parametrize("rank", [1, 2, 10]) @pytest.mark.parametrize("rank", [1, 2, 10])
def test_reduce_expand(rank, scale): @pytest.mark.parametrize("randomized", [True, False])
def test_reduce_expand(rank, scale, randomized):
pod = PODBlock(rank, scale) pod = PODBlock(rank, scale)
pod.fit(toy_snapshots) pod.fit(toy_snapshots, randomized)
torch.testing.assert_close( torch.testing.assert_close(
pod.expand(pod.reduce(toy_snapshots)), pod.expand(pod.reduce(toy_snapshots)),
toy_snapshots) toy_snapshots)