Fix basis device transfer in PODBlock (#650)

* fix gpu data moving
This commit is contained in:
Filippo Olivo
2025-09-19 13:42:32 +02:00
committed by GitHub
parent 87c5c6a674
commit 4a6e73fa54
2 changed files with 20 additions and 15 deletions

View File

@@ -1,7 +1,7 @@
"""Module for Base Continuous Convolution class."""
import torch
import warnings
import torch
class PODBlock(torch.nn.Module):
@@ -29,9 +29,10 @@ class PODBlock(torch.nn.Module):
"""
super().__init__()
self.__scale_coefficients = scale_coefficients
self._basis = None
self.register_buffer("_basis", None)
self._singular_values = None
self._scaler = None
self.register_buffer("_std", None)
self.register_buffer("_mean", None)
self._rank = rank
@property
@@ -94,12 +95,12 @@ class PODBlock(torch.nn.Module):
:return: The scaler dictionary.
:rtype: dict
"""
if self._scaler is None:
if self._std is None:
return None
return {
"mean": self._scaler["mean"][: self.rank],
"std": self._scaler["std"][: self.rank],
"mean": self._mean[: self.rank],
"std": self._std[: self.rank],
}
@property
@@ -119,6 +120,10 @@ class PODBlock(torch.nn.Module):
are scaled after the projection to have zero mean and unit variance.
:param torch.Tensor X: The input tensor to be reduced.
:param bool randomized: If ``True``, a randomized algorithm is used to
compute the POD basis. In general, this leads to faster
computations, but the results may be less accurate. Default is
``True``.
"""
self._fit_pod(X, randomized)
@@ -132,10 +137,8 @@ class PODBlock(torch.nn.Module):
:param torch.Tensor coeffs: The coefficients to be scaled.
"""
self._scaler = {
"std": torch.std(coeffs, dim=1),
"mean": torch.mean(coeffs, dim=1),
}
self._std = torch.std(coeffs, dim=1) # pylint: disable=W0201
self._mean = torch.mean(coeffs, dim=1) # pylint: disable=W0201
def _fit_pod(self, X, randomized):
"""
@@ -154,13 +157,14 @@ class PODBlock(torch.nn.Module):
else:
if randomized:
warnings.warn(
"Considering a randomized algorithm to compute the POD basis"
"Considering a randomized algorithm to compute the POD "
"basis"
)
u, s, _ = torch.svd_lowrank(X.T, q=X.shape[0])
else:
u, s, _ = torch.svd(X.T)
self._basis = u.T
self._basis = u.T # pylint: disable=W0201
self._singular_values = s
def forward(self, X):