Change SVD type in pod.py (#449)

* Change SVD type in pod.py
This commit is contained in:
Anna Ivagnes
2025-02-14 19:13:09 +01:00
committed by GitHub
parent d94256fac4
commit c28a4bd2b9
2 changed files with 22 additions and 11 deletions

View File

@@ -25,9 +25,10 @@ def test_fit(rank, scale):
@pytest.mark.parametrize("scale", [True, False])
@pytest.mark.parametrize("rank", [1, 2, 10])
def test_fit(rank, scale):
@pytest.mark.parametrize("randomized", [True, False])
def test_fit(rank, scale, randomized):
pod = PODBlock(rank, scale)
pod.fit(toy_snapshots)
pod.fit(toy_snapshots, randomized)
n_snap = toy_snapshots.shape[0]
dof = toy_snapshots.shape[1]
assert pod.basis.shape == (rank, dof)
@@ -65,18 +66,20 @@ def test_forward():
@pytest.mark.parametrize("scale", [True, False])
@pytest.mark.parametrize("rank", [1, 2, 10])
def test_expand(rank, scale):
@pytest.mark.parametrize("randomized", [True, False])
def test_expand(rank, scale, randomized):
pod = PODBlock(rank, scale)
pod.fit(toy_snapshots)
pod.fit(toy_snapshots, randomized)
c = pod(toy_snapshots)
torch.testing.assert_close(pod.expand(c), toy_snapshots)
torch.testing.assert_close(pod.expand(c[0]), toy_snapshots[0].unsqueeze(0))
@pytest.mark.parametrize("scale", [True, False])
@pytest.mark.parametrize("rank", [1, 2, 10])
def test_reduce_expand(rank, scale):
@pytest.mark.parametrize("randomized", [True, False])
def test_reduce_expand(rank, scale, randomized):
pod = PODBlock(rank, scale)
pod.fit(toy_snapshots)
pod.fit(toy_snapshots, randomized)
torch.testing.assert_close(
pod.expand(pod.reduce(toy_snapshots)),
toy_snapshots)