@@ -4,7 +4,7 @@ from abc import ABCMeta, abstractmethod
 import torch
 from .stride import Stride
 from .utils_convolution import optimizing
+import warnings
 
 class PODBlock(torch.nn.Module):
     """
@@ -85,7 +85,7 @@ class PODBlock(torch.nn.Module):
         """
         return self.__scale_coefficients
 
-    def fit(self, X):
+    def fit(self, X, randomized=True):
         """
         Set the POD basis by performing the singular value decomposition of the
         given tensor. If `self.scale_coefficients` is True, the coefficients
@@ -93,7 +93,7 @@ class PODBlock(torch.nn.Module):
 
         :param torch.Tensor X: The tensor to be reduced.
         """
-        self._fit_pod(X)
+        self._fit_pod(X, randomized)
 
         if self.__scale_coefficients:
            self._fit_scaler(torch.matmul(self._basis, X.T))
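
The change above makes the randomized solver the default while still letting callers opt out. A minimal usage sketch, assuming `PODBlock` is importable from this module and using made-up snapshot data; the constructor arguments follow the `PODBlock(rank, scale)` form used in the tests below:

```python
import torch
# from <this module> import PODBlock  # import path depends on the package layout

snapshots = torch.randn(10, 50)       # hypothetical data: 10 snapshots, 50 dof

pod = PODBlock(5, True)               # rank-5 basis, scaled coefficients
pod.fit(snapshots)                    # default: randomized SVD (torch.svd_lowrank)
pod.fit(snapshots, randomized=False)  # force the deterministic torch.svd path
```
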
@@ -112,16 +112,24 @@ class PODBlock(torch.nn.Module):
             "mean": torch.mean(coeffs, dim=1),
         }
 
-    def _fit_pod(self, X):
+    def _fit_pod(self, X, randomized):
         """
         Private method that computes the POD basis of the given tensor and stores it in the private member `_basis`.
 
         :param torch.Tensor X: The tensor to be reduced.
         """
         if X.device.type == "mps":  # svd_lowrank not available for mps
+            warnings.warn(
+                "svd_lowrank not available for mps, using svd instead. "
+                "This may slow down computations.", ResourceWarning
+            )
             self._basis = torch.svd(X.T)[0].T
         else:
-            self._basis = torch.svd_lowrank(X.T, q=X.shape[0])[0].T
+            if randomized:
+                warnings.warn("Using a randomized algorithm to compute the POD basis")
+                self._basis = torch.svd_lowrank(X.T, q=X.shape[0])[0].T
+            else:
+                self._basis = torch.svd(X.T)[0].T
 
     def forward(self, X):
         """
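
For intuition on the two branches: on a tall snapshot matrix, the deterministic `torch.svd` and the randomized `torch.svd_lowrank` (with `q` equal to the number of snapshots) recover essentially the same subspace, the randomized path just being cheaper. A small self-contained sketch of that claim, with made-up shapes and independent of `PODBlock`:

```python
import torch

torch.manual_seed(0)
X = torch.randn(8, 200)                           # hypothetical snapshots: (n_snap, dof)

# Deterministic branch (used on mps, or when randomized=False).
U_full = torch.svd(X.T)[0]                        # (dof, n_snap)

# Randomized branch (randomized=True), q = number of snapshots.
U_rand = torch.svd_lowrank(X.T, q=X.shape[0])[0]  # (dof, n_snap)

# Both bases span (numerically) the same subspace: the singular values of the
# cross-projection are the cosines of the principal angles, all close to 1.
print(torch.linalg.svdvals(U_full.T @ U_rand))
```
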
@@ -25,9 +25,10 @@ def test_fit(rank, scale):
 
 @pytest.mark.parametrize("scale", [True, False])
 @pytest.mark.parametrize("rank", [1, 2, 10])
-def test_fit(rank, scale):
+@pytest.mark.parametrize("randomized", [True, False])
+def test_fit(rank, scale, randomized):
     pod = PODBlock(rank, scale)
-    pod.fit(toy_snapshots)
+    pod.fit(toy_snapshots, randomized)
     n_snap = toy_snapshots.shape[0]
     dof = toy_snapshots.shape[1]
     assert pod.basis.shape == (rank, dof)
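
Stacking the extra `parametrize` decorator multiplies the existing matrix of cases, so every rank/scale combination is now exercised with both SVD paths (2 x 3 x 2 = 12 parameter sets per test). A minimal standalone illustration of that mechanism, with hypothetical names:

```python
import pytest

@pytest.mark.parametrize("scale", [True, False])
@pytest.mark.parametrize("rank", [1, 2, 10])
@pytest.mark.parametrize("randomized", [True, False])
def test_case_matrix(rank, scale, randomized):
    # pytest collects one test item per combination: 2 * 3 * 2 = 12 items.
    assert isinstance(randomized, bool)
```
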
@@ -65,18 +66,20 @@ def test_forward():
 
 @pytest.mark.parametrize("scale", [True, False])
 @pytest.mark.parametrize("rank", [1, 2, 10])
-def test_expand(rank, scale):
+@pytest.mark.parametrize("randomized", [True, False])
+def test_expand(rank, scale, randomized):
     pod = PODBlock(rank, scale)
-    pod.fit(toy_snapshots)
+    pod.fit(toy_snapshots, randomized)
     c = pod(toy_snapshots)
     torch.testing.assert_close(pod.expand(c), toy_snapshots)
     torch.testing.assert_close(pod.expand(c[0]), toy_snapshots[0].unsqueeze(0))
 
 @pytest.mark.parametrize("scale", [True, False])
 @pytest.mark.parametrize("rank", [1, 2, 10])
-def test_reduce_expand(rank, scale):
+@pytest.mark.parametrize("randomized", [True, False])
+def test_reduce_expand(rank, scale, randomized):
     pod = PODBlock(rank, scale)
-    pod.fit(toy_snapshots)
+    pod.fit(toy_snapshots, randomized)
     torch.testing.assert_close(
         pod.expand(pod.reduce(toy_snapshots)),
         toy_snapshots)
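
These round trips work because the toy snapshots lie in a low-dimensional subspace: once the basis covers that subspace, projecting onto the modes and expanding back is lossless, with either SVD variant. A self-contained sketch of the same identity with made-up data (not the `toy_snapshots` fixture):

```python
import torch

torch.manual_seed(0)
toy = torch.randn(10, 3) @ torch.randn(3, 50)   # exact rank-3 snapshots, 50 dof

# Rank-3 basis via the randomized SVD, mirroring fit(..., randomized=True).
basis = torch.svd_lowrank(toy.T, q=3)[0].T      # (3, 50), rows are POD modes

reduced = toy @ basis.T                         # coefficients, (10, 3)
expanded = reduced @ basis                      # back to snapshot space, (10, 50)

# Lossless up to floating-point error (tolerance loosened for the randomized solver).
torch.testing.assert_close(expanded, toy, rtol=1e-4, atol=1e-4)
```
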