Add SINDy model (#660)
This commit is contained in:
@@ -106,6 +106,7 @@ Models
|
||||
GraphNeuralKernel <model/graph_neural_operator_integral_kernel.rst>
|
||||
PirateNet <model/pirate_network.rst>
|
||||
EquivariantGraphNeuralOperator <model/equivariant_graph_neural_operator.rst>
|
||||
SINDy <model/sindy.rst>
|
||||
|
||||
Blocks
|
||||
-------------
|
||||
|
||||
7
docs/source/_rst/model/sindy.rst
Normal file
7
docs/source/_rst/model/sindy.rst
Normal file
@@ -0,0 +1,7 @@
|
||||
SINDy
|
||||
=======================
|
||||
.. currentmodule:: pina.model.sindy
|
||||
|
||||
.. autoclass:: SINDy
|
||||
:members:
|
||||
:show-inheritance:
|
||||
@@ -15,6 +15,7 @@ __all__ = [
|
||||
"GraphNeuralOperator",
|
||||
"PirateNet",
|
||||
"EquivariantGraphNeuralOperator",
|
||||
"SINDy",
|
||||
]
|
||||
|
||||
from .feed_forward import FeedForward, ResidualFeedForward
|
||||
@@ -28,3 +29,4 @@ from .spline import Spline
|
||||
from .graph_neural_operator import GraphNeuralOperator
|
||||
from .pirate_network import PirateNet
|
||||
from .equivariant_graph_neural_operator import EquivariantGraphNeuralOperator
|
||||
from .sindy import SINDy
|
||||
|
||||
102
pina/model/sindy.py
Normal file
102
pina/model/sindy.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""Module for the SINDy model class."""
|
||||
|
||||
from typing import Callable
|
||||
import torch
|
||||
from ..utils import check_consistency, check_positive_integer
|
||||
|
||||
|
||||
class SINDy(torch.nn.Module):
|
||||
r"""
|
||||
SINDy model class.
|
||||
|
||||
The Sparse Identification of Nonlinear Dynamics (SINDy) model identifies the
|
||||
governing equations of a dynamical system from data by learning a sparse
|
||||
linear combination of non-linear candidate functions.
|
||||
|
||||
The output of the model is expressed as product of a library matrix and a
|
||||
coefficient matrix:
|
||||
|
||||
.. math::
|
||||
|
||||
\dot{X} = \Theta(X) \Xi
|
||||
|
||||
where:
|
||||
- :math:`X \in \mathbb{R}^{B \times D}` is the input snapshots of the
|
||||
system state. Here, :math:`B` is the batch size and :math:`D` is the
|
||||
number of state variables.
|
||||
- :math:`\Theta(X) \in \mathbb{R}^{B \times L}` is the library matrix
|
||||
obtained by evaluating a set of candidate functions on the input data.
|
||||
Here, :math:`L` is the number of candidate functions in the library.
|
||||
- :math:`\Xi \in \mathbb{R}^{L \times D}` is the learned coefficient
|
||||
matrix that defines the sparse model.
|
||||
|
||||
.. seealso::
|
||||
|
||||
**Original reference**:
|
||||
Brunton, S.L., Proctor, J.L., and Kutz, J.N. (2016).
|
||||
*Discovering governing equations from data: Sparse identification of
|
||||
non-linear dynamical systems.*
|
||||
Proceedings of the National Academy of Sciences, 113(15), 3932-3937.
|
||||
DOI: `10.1073/pnas.1517384113
|
||||
<https://doi.org/10.1073/pnas.1517384113>`_
|
||||
"""
|
||||
|
||||
def __init__(self, library, output_dimension):
|
||||
"""
|
||||
Initialization of the :class:`SINDy` class.
|
||||
|
||||
:param list[Callable] library: The collection of candidate functions
|
||||
used to construct the library matrix. Each function must accept an
|
||||
input tensor of shape ``[..., D]`` and return a tensor of shape
|
||||
``[..., 1]``.
|
||||
:param int output_dimension: The number of output variables, typically
|
||||
the number of state derivatives. It determines the number of columns
|
||||
in the coefficient matrix.
|
||||
:raises ValueError: If ``library`` is not a list of callables.
|
||||
:raises AssertionError: If ``output_dimension`` is not a positive
|
||||
integer.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# Check consistency
|
||||
check_positive_integer(output_dimension, strict=True)
|
||||
check_consistency(library, Callable)
|
||||
if not isinstance(library, list):
|
||||
raise ValueError("`library` must be a list of callables.")
|
||||
|
||||
# Initialization
|
||||
self._library = library
|
||||
self._coefficients = torch.nn.Parameter(
|
||||
torch.zeros(len(library), output_dimension)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass of the :class:`SINDy` model.
|
||||
|
||||
:param torch.Tensor x: The input batch of state variables.
|
||||
:return: The predicted time derivatives of the state variables.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
theta = torch.stack([f(x) for f in self.library], dim=-2)
|
||||
return torch.einsum("...li , lo -> ...o", theta, self.coefficients)
|
||||
|
||||
@property
|
||||
def library(self):
|
||||
"""
|
||||
The library of candidate functions.
|
||||
|
||||
:return: The library.
|
||||
:rtype: list[Callable]
|
||||
"""
|
||||
return self._library
|
||||
|
||||
@property
|
||||
def coefficients(self):
|
||||
"""
|
||||
The coefficients of the model.
|
||||
|
||||
:return: The coefficients.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return self._coefficients
|
||||
55
tests/test_model/test_sindy.py
Normal file
55
tests/test_model/test_sindy.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import torch
|
||||
import pytest
|
||||
from pina.model import SINDy
|
||||
|
||||
# Define a simple library of candidate functions and some test data
|
||||
library = [lambda x: torch.pow(x, 2), lambda x: torch.sin(x)]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("data", [torch.rand((20, 1)), torch.rand((5, 20, 1))])
|
||||
def test_constructor(data):
|
||||
SINDy(library, data.shape[-1])
|
||||
|
||||
# Should fail if output_dimension is not a positive integer
|
||||
with pytest.raises(AssertionError):
|
||||
SINDy(library, "not_int")
|
||||
with pytest.raises(AssertionError):
|
||||
SINDy(library, -1)
|
||||
|
||||
# Should fail if library is not a list
|
||||
with pytest.raises(ValueError):
|
||||
SINDy(lambda x: torch.pow(x, 2), 3)
|
||||
|
||||
# Should fail if library is not a list of callables
|
||||
with pytest.raises(ValueError):
|
||||
SINDy([1, 2, 3], 3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("data", [torch.rand((20, 1)), torch.rand((5, 20, 1))])
|
||||
def test_forward(data):
|
||||
|
||||
# Define model
|
||||
model = SINDy(library, data.shape[-1])
|
||||
with torch.no_grad():
|
||||
model.coefficients.data.fill_(1.0)
|
||||
|
||||
# Evaluate model
|
||||
output_ = model(data)
|
||||
vals = data.pow(2) + torch.sin(data)
|
||||
|
||||
print(data.shape, output_.shape, vals.shape)
|
||||
|
||||
assert output_.shape == data.shape
|
||||
assert torch.allclose(output_, vals, atol=1e-6, rtol=1e-6)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("data", [torch.rand((20, 1)), torch.rand((5, 20, 1))])
|
||||
def test_backward(data):
|
||||
|
||||
# Define and evaluate model
|
||||
model = SINDy(library, data.shape[-1])
|
||||
output_ = model(data.requires_grad_())
|
||||
|
||||
loss = output_.mean()
|
||||
loss.backward()
|
||||
assert data.grad.shape == data.shape
|
||||
Reference in New Issue
Block a user