From 4a6e73fa54bcdcada1ce513957cb88c6a9024324 Mon Sep 17 00:00:00 2001 From: Filippo Olivo Date: Fri, 19 Sep 2025 13:42:32 +0200 Subject: [PATCH] Fix basis device transfer in PODBlock (#650) * fix gpu data moving --- pina/model/block/pod_block.py | 28 ++++++++++++++++------------ tests/test_block/test_pod.py | 7 ++++--- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/pina/model/block/pod_block.py b/pina/model/block/pod_block.py index 0c7990d..5ea2a35 100644 --- a/pina/model/block/pod_block.py +++ b/pina/model/block/pod_block.py @@ -1,7 +1,7 @@ """Module for Base Continuous Convolution class.""" -import torch import warnings +import torch class PODBlock(torch.nn.Module): @@ -29,9 +29,10 @@ class PODBlock(torch.nn.Module): """ super().__init__() self.__scale_coefficients = scale_coefficients - self._basis = None + self.register_buffer("_basis", None) self._singular_values = None - self._scaler = None + self.register_buffer("_std", None) + self.register_buffer("_mean", None) self._rank = rank @property @@ -94,12 +95,12 @@ class PODBlock(torch.nn.Module): :return: The scaler dictionary. :rtype: dict """ - if self._scaler is None: + if self._std is None: return None return { - "mean": self._scaler["mean"][: self.rank], - "std": self._scaler["std"][: self.rank], + "mean": self._mean[: self.rank], + "std": self._std[: self.rank], } @property @@ -119,6 +120,10 @@ class PODBlock(torch.nn.Module): are scaled after the projection to have zero mean and unit variance. :param torch.Tensor X: The input tensor to be reduced. + :param bool randomized: If ``True``, a randomized algorithm is used to + compute the POD basis. In general, this leads to faster + computations, but the results may be less accurate. Default is + ``True``. """ self._fit_pod(X, randomized) @@ -132,10 +137,8 @@ class PODBlock(torch.nn.Module): :param torch.Tensor coeffs: The coefficients to be scaled. """ - self._scaler = { - "std": torch.std(coeffs, dim=1), - "mean": torch.mean(coeffs, dim=1), - } + self._std = torch.std(coeffs, dim=1) # pylint: disable=W0201 + self._mean = torch.mean(coeffs, dim=1) # pylint: disable=W0201 def _fit_pod(self, X, randomized): """ @@ -154,13 +157,14 @@ class PODBlock(torch.nn.Module): else: if randomized: warnings.warn( - "Considering a randomized algorithm to compute the POD basis" + "Considering a randomized algorithm to compute the POD " + "basis" ) u, s, _ = torch.svd_lowrank(X.T, q=X.shape[0]) else: u, s, _ = torch.svd(X.T) - self._basis = u.T + self._basis = u.T # pylint: disable=W0201 self._singular_values = s def forward(self, X): diff --git a/tests/test_block/test_pod.py b/tests/test_block/test_pod.py index 8cee923..d10625f 100644 --- a/tests/test_block/test_pod.py +++ b/tests/test_block/test_pod.py @@ -42,13 +42,14 @@ def test_fit(rank, scale, randomized): assert pod.singular_values.shape == (rank,) assert pod._singular_values.shape == (n_snap,) if scale is True: - assert pod._scaler["mean"].shape == (n_snap,) - assert pod._scaler["std"].shape == (n_snap,) + assert pod._mean.shape == (n_snap,) + assert pod._std.shape == (n_snap,) assert pod.scaler["mean"].shape == (rank,) assert pod.scaler["std"].shape == (rank,) assert pod.scaler["mean"].shape[0] == pod.basis.shape[0] else: - assert pod._scaler == None + assert pod._std == None + assert pod._mean == None assert pod.scaler == None