@@ -4,7 +4,7 @@ from abc import ABCMeta, abstractmethod
|
||||
import torch
|
||||
from .stride import Stride
|
||||
from .utils_convolution import optimizing
|
||||
|
||||
import warnings
|
||||
|
||||
class PODBlock(torch.nn.Module):
|
||||
"""
|
||||
@@ -85,7 +85,7 @@ class PODBlock(torch.nn.Module):
|
||||
"""
|
||||
return self.__scale_coefficients
|
||||
|
||||
def fit(self, X):
|
||||
def fit(self, X, randomized=True):
|
||||
"""
|
||||
Set the POD basis by performing the singular value decomposition of the
|
||||
given tensor. If `self.scale_coefficients` is True, the coefficients
|
||||
@@ -93,7 +93,7 @@ class PODBlock(torch.nn.Module):
|
||||
|
||||
:param torch.Tensor X: The tensor to be reduced.
|
||||
"""
|
||||
self._fit_pod(X)
|
||||
self._fit_pod(X, randomized)
|
||||
|
||||
if self.__scale_coefficients:
|
||||
self._fit_scaler(torch.matmul(self._basis, X.T))
|
||||
@@ -112,16 +112,24 @@ class PODBlock(torch.nn.Module):
|
||||
"mean": torch.mean(coeffs, dim=1),
|
||||
}
|
||||
|
||||
def _fit_pod(self, X):
|
||||
def _fit_pod(self, X, randomized):
|
||||
"""
|
||||
Private method that computes the POD basis of the given tensor and stores it in the private member `_basis`.
|
||||
|
||||
:param torch.Tensor X: The tensor to be reduced.
|
||||
"""
|
||||
if X.device.type == "mps": # svd_lowrank not arailable for mps
|
||||
warnings.warn(
|
||||
"svd_lowrank not available for mps, using svd instead."
|
||||
"This may slow down computations.", ResourceWarning
|
||||
)
|
||||
self._basis = torch.svd(X.T)[0].T
|
||||
else:
|
||||
self._basis = torch.svd_lowrank(X.T, q=X.shape[0])[0].T
|
||||
if randomized:
|
||||
warnings.warn("Considering a randomized algorithm to compute the POD basis")
|
||||
self._basis = torch.svd_lowrank(X.T, q=X.shape[0])[0].T
|
||||
else:
|
||||
self._basis = torch.svd(X.T)[0].T
|
||||
|
||||
def forward(self, X):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user