From c65631705548892124a793ece85a30a65dae6002 Mon Sep 17 00:00:00 2001 From: Nicola Demo Date: Wed, 29 Nov 2023 17:14:15 +0100 Subject: [PATCH] pod layer --- pina/model/layers/pod.py | 177 ++++++++++++++++++++++++++++++++++ tests/test_layers/test_pod.py | 86 +++++++++++++++++ 2 files changed, 263 insertions(+) create mode 100644 pina/model/layers/pod.py create mode 100644 tests/test_layers/test_pod.py diff --git a/pina/model/layers/pod.py b/pina/model/layers/pod.py new file mode 100644 index 0000000..299b0ac --- /dev/null +++ b/pina/model/layers/pod.py @@ -0,0 +1,177 @@ +"""Module for Base Continuous Convolution class.""" +from abc import ABCMeta, abstractmethod +import torch +from .stride import Stride +from .utils_convolution import optimizing + + +class PODLayer(torch.nn.Module): + """ + POD layer: it projects the input field on the proper orthogonal + decomposition basis. It needs to be fitted to the data before being used + with the method :meth:`fit`, which invokes the singular value decomposition. + The layer is not trainable. + + .. note:: + All the POD modes are stored in memory, avoiding to recompute them when the rank changes but increasing the memory usage. + """ + + def __init__(self, rank, scale_coefficients=True): + """ + Build the POD layer with the given rank. + + :param int rank: The rank of the POD layer. + :param bool scale_coefficients: If True, the coefficients are scaled + after the projection to have zero mean and unit variance. + """ + super().__init__() + self.__scale_coefficients = scale_coefficients + self._basis = None + self._scaler = None + self._rank = rank + + @property + def rank(self): + """ + The rank of the POD layer. + + :rtype: int + """ + return self._rank + + @rank.setter + def rank(self, value): + if value < 1 or not isinstance(value, int): + raise ValueError('The rank must be positive integer') + + self._rank = value + + @property + def basis(self): + """ + The POD basis. It is a matrix whose columns are the first `self.rank` POD modes. + + :rtype: torch.Tensor + """ + if self._basis is None: + return None + + return self._basis[:self.rank] + + @property + def scaler(self): + """ + The scaler. It is a dictionary with the keys `'mean'` and `'std'` that + store the mean and the standard deviation of the coefficients. + + :rtype: dict + """ + if self._scaler is None: + return + + return {'mean': self._scaler['mean'][:self.rank], + 'std': self._scaler['std'][:self.rank]} + + @property + def scale_coefficients(self): + """ + If True, the coefficients are scaled after the projection to have zero + mean and unit variance. + + :rtype: bool + """ + return self.__scale_coefficients + + def fit(self, X): + """ + Set the POD basis by performing the singular value decomposition of the + given tensor. If `self.scale_coefficients` is True, the coefficients + are scaled after the projection to have zero mean and unit variance. + + :param torch.Tensor X: The tensor to be reduced. + """ + self._fit_pod(X) + + if self.__scale_coefficients: + self._fit_scaler(torch.matmul(self._basis, X.T)) + + def _fit_scaler(self, coeffs): + """ + Private merhod that computes the mean and the standard deviation of the + given coefficients, allowing to scale them to have zero mean and unit + variance. Mean and standard deviation are stored in the private member + `_scaler`. + + :param torch.Tensor coeffs: The coefficients to be scaled. + """ + self._scaler = { + 'std': torch.std(coeffs, dim=1), + 'mean': torch.mean(coeffs, dim=1)} + + def _fit_pod(self, X): + """ + Private method that computes the POD basis of the given tensor and stores it in the private member `_basis`. + + :param torch.Tensor X: The tensor to be reduced. + """ + if X.device.type == 'mps': # svd_lowrank not arailable for mps + self._basis = torch.svd(X.T)[0].T + else: + self._basis = torch.svd_lowrank(X.T, q=X.shape[0])[0].T + + def forward(self, X): + """ + The forward pass of the POD layer. By default it executes the + :meth:`reduce` method, reducing the input tensor to its POD + representation. The POD layer needs to be fitted before being used. + + :param torch.Tensor X: The input tensor to be reduced. + :return: The reduced tensor. + :rtype: torch.Tensor + """ + return self.reduce(X) + + def reduce(self, X): + """ + Reduce the input tensor to its POD representation. The POD layer needs + to be fitted before being used. + + :param torch.Tensor X: The input tensor to be reduced. + :return: The reduced tensor. + :rtype: torch.Tensor + """ + if self._basis is None: + raise RuntimeError( + 'The POD layer needs to be fitted before being used.') + + coeff = torch.matmul(self.basis, X.T) + if coeff.ndim == 1: + coeff = coeff.unsqueeze(1) + + coeff = coeff.T + if self.__scale_coefficients: + coeff = (coeff - self.scaler['mean']) / self.scaler['std'] + + return coeff + + def expand(self, coeff): + """ + Expand the given coefficients to the original space. The POD layer needs + to be fitted before being used. + + :param torch.Tensor coeff: The coefficients to be expanded. + :return: The expanded tensor. + :rtype: torch.Tensor + """ + if self._basis is None: + raise RuntimeError( + 'The POD layer needs to be trained before being used.') + + if self.__scale_coefficients: + coeff = coeff * self.scaler['std'] + self.scaler['mean'] + predicted = torch.matmul(self.basis.T, coeff.T).T + + if predicted.ndim == 1: + predicted = predicted.unsqueeze(0) + + return predicted \ No newline at end of file diff --git a/tests/test_layers/test_pod.py b/tests/test_layers/test_pod.py new file mode 100644 index 0000000..878ae5a --- /dev/null +++ b/tests/test_layers/test_pod.py @@ -0,0 +1,86 @@ +import torch +import pytest + +from pina.model.layers.pod import PODLayer + +x = torch.linspace(-1, 1, 100) +toy_snapshots = torch.vstack([torch.exp(-x**2)*c for c in torch.linspace(0, 1, 10)]) + +def test_constructor(): + pod = PODLayer(2) + pod = PODLayer(2, True) + pod = PODLayer(2, False) + with pytest.raises(TypeError): + pod = PODLayer() + + +@pytest.mark.parametrize("rank", [1, 2, 10]) +def test_fit(rank, scale): + pod = PODLayer(rank, scale) + assert pod._basis == None + assert pod.basis == None + assert pod._scaler == None + assert pod.rank == rank + assert pod.scale_coefficients == scale + +@pytest.mark.parametrize("scale", [True, False]) +@pytest.mark.parametrize("rank", [1, 2, 10]) +def test_fit(rank, scale): + pod = PODLayer(rank, scale) + pod.fit(toy_snapshots) + n_snap = toy_snapshots.shape[0] + dof = toy_snapshots.shape[1] + assert pod.basis.shape == (rank, dof) + assert pod._basis.shape == (n_snap, dof) + if scale is True: + assert pod._scaler['mean'].shape == (n_snap,) + assert pod._scaler['std'].shape == (n_snap,) + assert pod.scaler['mean'].shape == (rank,) + assert pod.scaler['std'].shape == (rank,) + assert pod.scaler['mean'].shape[0] == pod.basis.shape[0] + else: + assert pod._scaler == None + assert pod.scaler == None + +def test_forward(): + pod = PODLayer(1) + pod.fit(toy_snapshots) + c = pod(toy_snapshots) + assert c.shape[0] == toy_snapshots.shape[0] + assert c.shape[1] == pod.rank + torch.testing.assert_close(c.mean(dim=0), torch.zeros(pod.rank)) + torch.testing.assert_close(c.std(dim=0), torch.ones(pod.rank)) + + c = pod(toy_snapshots[0]) + assert c.shape[1] == pod.rank + assert c.shape[0] == 1 + + pod = PODLayer(2, False) + pod.fit(toy_snapshots) + c = pod(toy_snapshots) + torch.testing.assert_close(c, (pod.basis @ toy_snapshots.T).T) + with pytest.raises(AssertionError): + torch.testing.assert_close(c.mean(dim=0), torch.zeros(pod.rank)) + torch.testing.assert_close(c.std(dim=0), torch.ones(pod.rank)) + +@pytest.mark.parametrize("scale", [True, False]) +@pytest.mark.parametrize("rank", [1, 2, 10]) +def test_expand(rank, scale): + pod = PODLayer(rank, scale) + pod.fit(toy_snapshots) + c = pod(toy_snapshots) + torch.testing.assert_close(pod.expand(c), toy_snapshots) + torch.testing.assert_close(pod.expand(c[0]), toy_snapshots[0].unsqueeze(0)) + +@pytest.mark.parametrize("scale", [True, False]) +@pytest.mark.parametrize("rank", [1, 2, 10]) +def test_reduce_expand(rank, scale): + pod = PODLayer(rank, scale) + pod.fit(toy_snapshots) + torch.testing.assert_close( + pod.expand(pod.reduce(toy_snapshots)), + toy_snapshots) + torch.testing.assert_close( + pod.expand(pod.reduce(toy_snapshots[0])), + toy_snapshots[0].unsqueeze(0)) + # torch.testing.assert_close(pod.expand(pod.reduce(c[0])), c[0]) \ No newline at end of file