Fix basis device transfer in PODBlock (#650)

* fix gpu data moving
This commit is contained in:
Filippo Olivo
2025-09-19 13:42:32 +02:00
committed by GitHub
parent 87c5c6a674
commit 4a6e73fa54
2 changed files with 20 additions and 15 deletions

View File

@@ -42,13 +42,14 @@ def test_fit(rank, scale, randomized):
assert pod.singular_values.shape == (rank,)
assert pod._singular_values.shape == (n_snap,)
if scale is True:
assert pod._scaler["mean"].shape == (n_snap,)
assert pod._scaler["std"].shape == (n_snap,)
assert pod._mean.shape == (n_snap,)
assert pod._std.shape == (n_snap,)
assert pod.scaler["mean"].shape == (rank,)
assert pod.scaler["std"].shape == (rank,)
assert pod.scaler["mean"].shape[0] == pod.basis.shape[0]
else:
assert pod._scaler == None
assert pod._std == None
assert pod._mean == None
assert pod.scaler == None