pod layer

pina/model/layers/pod.py (new file, 177 lines)
@@ -0,0 +1,177 @@
"""Module for the Proper Orthogonal Decomposition (POD) layer."""

import torch


class PODLayer(torch.nn.Module):
    """
    POD layer: it projects the input field onto the proper orthogonal
    decomposition basis. It needs to be fitted to the data before being
    used, via the method :meth:`fit`, which invokes the singular value
    decomposition. The layer is not trainable.

    .. note::
        All the POD modes are stored in memory, avoiding recomputation when
        the rank changes, at the cost of a larger memory footprint.
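
    :Example: a minimal sketch of the intended fit/reduce/expand round trip
        (shapes follow the accompanying tests):

        >>> snapshots = torch.rand(10, 100)  # (n_snapshots, n_dof)
        >>> pod = PODLayer(rank=2)
        >>> pod.fit(snapshots)
        >>> pod(snapshots).shape             # forward() calls reduce()
        torch.Size([10, 2])
        >>> pod.expand(pod(snapshots)).shape
        torch.Size([10, 100])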
"""
|
||||
|
||||
    def __init__(self, rank, scale_coefficients=True):
        """
        Build the POD layer with the given rank.

        :param int rank: The rank of the POD layer.
        :param bool scale_coefficients: If True, the coefficients are scaled
            after the projection to have zero mean and unit variance.
        """
        super().__init__()
        self.__scale_coefficients = scale_coefficients
        self._basis = None
        self._scaler = None
        self.rank = rank  # validated by the `rank` property setter

    @property
    def rank(self):
        """
        The rank of the POD layer.

        :rtype: int
        """
        return self._rank

    @rank.setter
    def rank(self, value):
        # Check the type before comparing, so non-integers raise cleanly.
        if not isinstance(value, int) or value < 1:
            raise ValueError('The rank must be a positive integer')

        self._rank = value

    @property
    def basis(self):
        """
        The POD basis. It is a matrix whose rows are the first `self.rank`
        POD modes.

        :rtype: torch.Tensor
        """
        if self._basis is None:
            return None

        return self._basis[:self.rank]

    @property
    def scaler(self):
        """
        The scaler. It is a dictionary with the keys `'mean'` and `'std'` that
        store the mean and the standard deviation of the coefficients.

        :rtype: dict
        """
        if self._scaler is None:
            return None

        return {'mean': self._scaler['mean'][:self.rank],
                'std': self._scaler['std'][:self.rank]}

    @property
    def scale_coefficients(self):
        """
        If True, the coefficients are scaled after the projection to have zero
        mean and unit variance.

        :rtype: bool
        """
        return self.__scale_coefficients

    def fit(self, X):
        """
        Set the POD basis by performing the singular value decomposition of
        the given tensor. If `self.scale_coefficients` is True, the
        coefficients are scaled after the projection to have zero mean and
        unit variance.

        :param torch.Tensor X: The tensor to be reduced.
        """
        self._fit_pod(X)

        if self.__scale_coefficients:
            self._fit_scaler(torch.matmul(self._basis, X.T))

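    # Shape conventions, inferred from the tests that accompany this layer:
    # for snapshots X of shape (n_snap, n_dof), `_basis` has shape
    # (n_snap, n_dof) with one POD mode per row, so the projected
    # coefficients `_basis @ X.T` have shape (n_snap, n_snap), one column
    # per snapshot; `_fit_scaler` hence takes statistics along dim=1.
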
    def _fit_scaler(self, coeffs):
        """
        Private method that computes the mean and the standard deviation of
        the given coefficients, so that they can be scaled to zero mean and
        unit variance. Mean and standard deviation are stored in the private
        member `_scaler`.

        :param torch.Tensor coeffs: The coefficients to be scaled.
        """
        self._scaler = {
            'std': torch.std(coeffs, dim=1),
            'mean': torch.mean(coeffs, dim=1)}

    def _fit_pod(self, X):
        """
        Private method that computes the POD basis of the given tensor and
        stores it in the private member `_basis`.

        :param torch.Tensor X: The tensor to be reduced.
        """
        if X.device.type == 'mps':  # svd_lowrank not available for mps
            self._basis = torch.svd(X.T)[0].T
        else:
            self._basis = torch.svd_lowrank(X.T, q=X.shape[0])[0].T

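    # NOTE: `q=X.shape[0]` keeps as many modes as there are snapshots, so
    # `_basis` stores the full mode set and the `basis` property can slice
    # any rank later without refitting (see the class-level note).
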
    def forward(self, X):
        """
        The forward pass of the POD layer. By default it executes the
        :meth:`reduce` method, reducing the input tensor to its POD
        representation. The POD layer needs to be fitted before being used.

        :param torch.Tensor X: The input tensor to be reduced.
        :return: The reduced tensor.
        :rtype: torch.Tensor
        """
        return self.reduce(X)

    def reduce(self, X):
        """
        Reduce the input tensor to its POD representation. The POD layer
        needs to be fitted before being used.

        :param torch.Tensor X: The input tensor to be reduced.
        :return: The reduced tensor.
        :rtype: torch.Tensor
        """
        if self._basis is None:
            raise RuntimeError(
                'The POD layer needs to be fitted before being used.')

        coeff = torch.matmul(self.basis, X.T)
        if coeff.ndim == 1:  # single snapshot: promote to a batch of one
            coeff = coeff.unsqueeze(1)

        coeff = coeff.T
        if self.__scale_coefficients:
            coeff = (coeff - self.scaler['mean']) / self.scaler['std']

        return coeff

    def expand(self, coeff):
        """
        Expand the given coefficients to the original space. The POD layer
        needs to be fitted before being used.

        :param torch.Tensor coeff: The coefficients to be expanded.
        :return: The expanded tensor.
        :rtype: torch.Tensor
        """
        if self._basis is None:
            raise RuntimeError(
                'The POD layer needs to be fitted before being used.')

        if self.__scale_coefficients:
            coeff = coeff * self.scaler['std'] + self.scaler['mean']
        predicted = torch.matmul(self.basis.T, coeff.T).T

        if predicted.ndim == 1:  # single snapshot: promote to a batch of one
            predicted = predicted.unsqueeze(0)

        return predicted

tests/test_layers/test_pod.py (new file, 86 lines)
@@ -0,0 +1,86 @@
import torch
import pytest

from pina.model.layers.pod import PODLayer

x = torch.linspace(-1, 1, 100)
toy_snapshots = torch.vstack(
    [torch.exp(-x**2) * c for c in torch.linspace(0, 1, 10)])

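# Every snapshot is a scalar multiple of the same Gaussian profile, so the
# snapshot matrix has rank one: a single POD mode already reconstructs the
# data exactly, which is why the round-trip tests below pass for every rank.
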
def test_constructor():
    pod = PODLayer(2)
    pod = PODLayer(2, True)
    pod = PODLayer(2, False)
    with pytest.raises(TypeError):
        pod = PODLayer()


@pytest.mark.parametrize("scale", [True, False])
@pytest.mark.parametrize("rank", [1, 2, 10])
def test_before_fit(rank, scale):
    pod = PODLayer(rank, scale)
    assert pod._basis is None
    assert pod.basis is None
    assert pod._scaler is None
    assert pod.rank == rank
    assert pod.scale_coefficients == scale


@pytest.mark.parametrize("scale", [True, False])
@pytest.mark.parametrize("rank", [1, 2, 10])
def test_fit(rank, scale):
    pod = PODLayer(rank, scale)
    pod.fit(toy_snapshots)
    n_snap = toy_snapshots.shape[0]
    dof = toy_snapshots.shape[1]
    assert pod.basis.shape == (rank, dof)
    assert pod._basis.shape == (n_snap, dof)
    if scale is True:
        assert pod._scaler['mean'].shape == (n_snap,)
        assert pod._scaler['std'].shape == (n_snap,)
        assert pod.scaler['mean'].shape == (rank,)
        assert pod.scaler['std'].shape == (rank,)
        assert pod.scaler['mean'].shape[0] == pod.basis.shape[0]
    else:
        assert pod._scaler is None
        assert pod.scaler is None


def test_forward():
    pod = PODLayer(1)
    pod.fit(toy_snapshots)
    c = pod(toy_snapshots)
    assert c.shape[0] == toy_snapshots.shape[0]
    assert c.shape[1] == pod.rank
    torch.testing.assert_close(c.mean(dim=0), torch.zeros(pod.rank))
    torch.testing.assert_close(c.std(dim=0), torch.ones(pod.rank))

    c = pod(toy_snapshots[0])
    assert c.shape[1] == pod.rank
    assert c.shape[0] == 1

    # Without scaling, the coefficients are the plain projection and are
    # neither centered nor normalized.
    pod = PODLayer(2, False)
    pod.fit(toy_snapshots)
    c = pod(toy_snapshots)
    torch.testing.assert_close(c, (pod.basis @ toy_snapshots.T).T)
    with pytest.raises(AssertionError):
        torch.testing.assert_close(c.mean(dim=0), torch.zeros(pod.rank))
    with pytest.raises(AssertionError):
        torch.testing.assert_close(c.std(dim=0), torch.ones(pod.rank))


@pytest.mark.parametrize("scale", [True, False])
@pytest.mark.parametrize("rank", [1, 2, 10])
def test_expand(rank, scale):
    pod = PODLayer(rank, scale)
    pod.fit(toy_snapshots)
    c = pod(toy_snapshots)
    torch.testing.assert_close(pod.expand(c), toy_snapshots)
    torch.testing.assert_close(pod.expand(c[0]), toy_snapshots[0].unsqueeze(0))


@pytest.mark.parametrize("scale", [True, False])
@pytest.mark.parametrize("rank", [1, 2, 10])
def test_reduce_expand(rank, scale):
    pod = PODLayer(rank, scale)
    pod.fit(toy_snapshots)
    torch.testing.assert_close(
        pod.expand(pod.reduce(toy_snapshots)),
        toy_snapshots)
    torch.testing.assert_close(
        pod.expand(pod.reduce(toy_snapshots[0])),
        toy_snapshots[0].unsqueeze(0))
    # torch.testing.assert_close(pod.expand(pod.reduce(c[0])), c[0])