diff --git a/lightning_logs/version_0/hparams.yaml b/lightning_logs/version_0/hparams.yaml new file mode 100644 index 0000000..0967ef4 --- /dev/null +++ b/lightning_logs/version_0/hparams.yaml @@ -0,0 +1 @@ +{} diff --git a/lightning_logs/version_1/checkpoints/epoch=4-step=5.ckpt b/lightning_logs/version_1/checkpoints/epoch=4-step=5.ckpt new file mode 100644 index 0000000..d3ca30f Binary files /dev/null and b/lightning_logs/version_1/checkpoints/epoch=4-step=5.ckpt differ diff --git a/lightning_logs/version_1/hparams.yaml b/lightning_logs/version_1/hparams.yaml new file mode 100644 index 0000000..0967ef4 --- /dev/null +++ b/lightning_logs/version_1/hparams.yaml @@ -0,0 +1 @@ +{} diff --git a/lightning_logs/version_2/checkpoints/epoch=14-step=15.ckpt b/lightning_logs/version_2/checkpoints/epoch=14-step=15.ckpt new file mode 100644 index 0000000..57e5e5b Binary files /dev/null and b/lightning_logs/version_2/checkpoints/epoch=14-step=15.ckpt differ diff --git a/lightning_logs/version_2/hparams.yaml b/lightning_logs/version_2/hparams.yaml new file mode 100644 index 0000000..0967ef4 --- /dev/null +++ b/lightning_logs/version_2/hparams.yaml @@ -0,0 +1 @@ +{} diff --git a/lightning_logs/version_3/checkpoints/epoch=4-step=5.ckpt b/lightning_logs/version_3/checkpoints/epoch=4-step=5.ckpt new file mode 100644 index 0000000..fde95e8 Binary files /dev/null and b/lightning_logs/version_3/checkpoints/epoch=4-step=5.ckpt differ diff --git a/lightning_logs/version_3/hparams.yaml b/lightning_logs/version_3/hparams.yaml new file mode 100644 index 0000000..0967ef4 --- /dev/null +++ b/lightning_logs/version_3/hparams.yaml @@ -0,0 +1 @@ +{} diff --git a/pina/model/layers/__init__.py b/pina/model/layers/__init__.py index 7038c6d..71d781c 100644 --- a/pina/model/layers/__init__.py +++ b/pina/model/layers/__init__.py @@ -1,7 +1,11 @@ __all__ = [ 'ContinuousConvBlock', - 'ResidualBlock' + 'ResidualBlock', + 'SpectralConvBlock1D', + 'SpectralConvBlock2D', + 'SpectralConvBlock3D' ] from .convolution_2d import ContinuousConvBlock from .residual import ResidualBlock +from .spectral import SpectralConvBlock1D, SpectralConvBlock2D, SpectralConvBlock3D diff --git a/pina/model/layers/spectral.py b/pina/model/layers/spectral.py index e685164..465ff18 100644 --- a/pina/model/layers/spectral.py +++ b/pina/model/layers/spectral.py @@ -1,16 +1,326 @@ import torch import torch.nn as nn from ...utils import check_consistency +import warnings - -class SpectralConvBlock(nn.Module): +######## 1D Spectral Convolution ########### +class SpectralConvBlock1D(nn.Module): """ - Implementation of spectral convolution block. + Implementation of Spectral Convolution Block for one + dimensional tensor. """ - def __init__(self): + def __init__(self, input_numb_fields, output_numb_fields, n_modes): + """ + TODO + + :param input_numb_fields: _description_ + :type input_numb_fields: _type_ + :param output_numb_fields: _description_ + :type output_numb_fields: _type_ + :param n_modes: _description_ + :type n_modes: _type_ + """ super().__init__() + # check type consistency + check_consistency(input_numb_fields, int) + check_consistency(output_numb_fields, int) + + # assign variables + self._modes = n_modes + self._input_channels = input_numb_fields + self._output_channels = output_numb_fields + + # scaling factor + scale = (1. / (self._input_channels * self._output_channels)) + self._weights = nn.Parameter(scale * torch.rand(self._input_channels, + self._output_channels, + self._modes, + dtype=torch.cfloat)) + + def _compute_mult1d(self, input, weights): + """ + Compute the matrix multiplication of the input + with the linear kernel weights. + + :param input: The input tensor, expect of size + [batch, input_numb_fields, x]. + :type input: torch.Tensor + :param weights: The kernel weights, expect of + size [input_numb_fields, output_numb_fields, x]. + :type weights: torch.Tensor + :return: The matrix multiplication of the input + with the linear kernel weights. + :rtype: torch.Tensor + """ + return torch.einsum("bix,iox->box", input, weights) def forward(self, x): - pass \ No newline at end of file + """ + Forward computation for Spectral Convolution. + + :param x: The input tensor, expect of size + [batch, input_numb_fields, x]. + :type x: torch.Tensor + :return: The output tensor obtained from the + spectral convolution of size [batch, output_numb_fields, x]. + :rtype: torch.Tensor + """ + batch_size = x.shape[0] + + # if x.shape[-1] // 2 + 1 < self._modes: + # raise RuntimeError('Number of modes is too high, decrease number of modes.') + + # Compute Fourier transform of the input + x_ft = torch.fft.rfft(x) + + # Multiply relevant Fourier modes + out_ft = torch.zeros(batch_size, + self._output_channels, + x.size(-1) // 2 + 1, + device=x.device, + dtype=torch.cfloat) + out_ft[:, :, :self._modes] = self._compute_mult1d(x_ft[:, :, :self._modes], self._weights) + + # Return to physical space + return torch.fft.irfft(out_ft, n=x.size(-1)) + + +######## 2D Spectral Convolution ########### +class SpectralConvBlock2D(nn.Module): + """ + Implementation of spectral convolution block for two + dimensional tensor. + """ + + def __init__(self, input_numb_fields, output_numb_fields, n_modes): + super().__init__() + + # check type consistency + check_consistency(input_numb_fields, int) + check_consistency(output_numb_fields, int) + if not isinstance(n_modes, (tuple, list)): + raise ValueError('expected n_modes to be a list or tuple of len two, ' + 'with each entry corresponding to the number of modes ' + 'for each dimension ') + if len(n_modes) != 2: + raise ValueError('expected n_modes to be a list or tuple of len two, ' + 'with each entry corresponding to the number of modes ' + 'for each dimension ') + check_consistency(n_modes, int) + + + # assign variables + self._modes = n_modes + self._input_channels = input_numb_fields + self._output_channels = output_numb_fields + + # scaling factor + scale = (1. / (self._input_channels * self._output_channels)) + self._weights1 = nn.Parameter(scale * torch.rand(self._input_channels, + self._output_channels, + self._modes[0], + self._modes[1], + dtype=torch.cfloat)) + self._weights2 = nn.Parameter(scale * torch.rand(self._input_channels, + self._output_channels, + self._modes[0], + self._modes[1], + dtype=torch.cfloat)) + + def _compute_mult2d(self, input, weights): + """ + Compute the matrix multiplication of the input + with the linear kernel weights. + + :param input: The input tensor, expect of size + [batch, input_numb_fields, x, y]. + :type input: torch.Tensor + :param weights: The kernel weights, expect of + size [input_numb_fields, output_numb_fields, x, y]. + :type weights: torch.Tensor + :return: The matrix multiplication of the input + with the linear kernel weights. + :rtype: torch.Tensor + """ + return torch.einsum("bixy,ioxy->boxy", input, weights) + + def forward(self, x): + """ + Forward computation for Spectral Convolution. + + :param x: The input tensor, expect of size + [batch, input_numb_fields, x]. + :type x: torch.Tensor + :return: The output tensor obtained from the + spectral convolution of size [batch, output_numb_fields, x]. + :rtype: torch.Tensor + """ + + batch_size = x.shape[0] + + # Compute Fourier transform of the input + x_ft = torch.fft.rfft2(x) + + # Multiply relevant Fourier modes + out_ft = torch.zeros(batch_size, + self._output_channels, + x.size(-2), + x.size(-1)//2 + 1, + device=x.device, + dtype=torch.cfloat) + out_ft[:, :, :self._modes[0], :self._modes[1]] = self._compute_mult2d(x_ft[:, :, :self._modes[0], :self._modes[1]], + self._weights1) + out_ft[:, :, -self._modes[0]:, :self._modes[1]:] = self._compute_mult2d(x_ft[:, :, -self._modes[0]:, :self._modes[1]], + self._weights2) + + # Return to physical space + return torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1))) + + +######## 2D Spectral Convolution ########### +class SpectralConvBlock3D(nn.Module): + """ + Implementation of spectral convolution block for two + dimensional tensor. + """ + + def __init__(self, input_numb_fields, output_numb_fields, n_modes): + """ + TODO + + :param input_numb_fields: _description_ + :type input_numb_fields: _type_ + :param output_numb_fields: _description_ + :type output_numb_fields: _type_ + :param n_modes: _description_ + :type n_modes: _type_ + :raises ValueError: _description_ + :raises ValueError: _description_ + """ + super().__init__() + + # check type consistency + check_consistency(input_numb_fields, int) + check_consistency(output_numb_fields, int) + if not isinstance(n_modes, (tuple, list)): + raise ValueError('expected n_modes to be a list or tuple of len three, ' + 'with each entry corresponding to the number of modes ' + 'for each dimension ') + if len(n_modes) != 3: + raise ValueError('expected n_modes to be a list or tuple of len three, ' + 'with each entry corresponding to the number of modes ' + 'for each dimension ') + check_consistency(n_modes, int) + + # assign variables + self._modes = n_modes + self._input_channels = input_numb_fields + self._output_channels = output_numb_fields + + # scaling factor + scale = (1. / (self._input_channels * self._output_channels)) + self._weights1 = nn.Parameter(scale * torch.rand(self._input_channels, + self._output_channels, + self._modes[0], + self._modes[1], + self._modes[2], + dtype=torch.cfloat)) + self._weights2 = nn.Parameter(scale * torch.rand(self._input_channels, + self._output_channels, + self._modes[0], + self._modes[1], + self._modes[2], + dtype=torch.cfloat)) + self._weights3 = nn.Parameter(scale * torch.rand(self._input_channels, + self._output_channels, + self._modes[0], + self._modes[1], + self._modes[2], + dtype=torch.cfloat)) + self._weights4 = nn.Parameter(scale * torch.rand(self._input_channels, + self._output_channels, + self._modes[0], + self._modes[1], + self._modes[2], + dtype=torch.cfloat)) + + def _compute_mult3d(self, input, weights): + """ + Compute the matrix multiplication of the input + with the linear kernel weights. + + :param input: The input tensor, expect of size + [batch, input_numb_fields, x, y]. + :type input: torch.Tensor + :param weights: The kernel weights, expect of + size [input_numb_fields, output_numb_fields, x, y]. + :type weights: torch.Tensor + :return: The matrix multiplication of the input + with the linear kernel weights. + :rtype: torch.Tensor + """ + return torch.einsum("bixyz,ioxyz->boxyz", input, weights) + + def forward(self, x): + """ + Forward computation for Spectral Convolution. + + :param x: The input tensor, expect of size + [batch, input_numb_fields, x]. + :type x: torch.Tensor + :return: The output tensor obtained from the + spectral convolution of size [batch, output_numb_fields, x]. + :rtype: torch.Tensor + """ + + batch_size = x.shape[0] + + # Compute Fourier transform of the input + x_ft = torch.fft.rfftn(x, dim=[-3, -2, -1]) + + # Multiply relevant Fourier modes + out_ft = torch.zeros(batch_size, + self._output_channels, + x.size(-3), + x.size(-2), + x.size(-1)//2 + 1, + device=x.device, + dtype=torch.cfloat) + + slice0 = (slice(None), + slice(None), + slice(self._modes[0]), + slice(self._modes[1]), + slice(self._modes[2]), + ) + out_ft[slice0] = self._compute_mult3d(x_ft[slice0], self._weights1) + + slice1 = (slice(None), + slice(None), + slice(self._modes[0]), + slice(-self._modes[1], None), + slice(self._modes[2]), + ) + out_ft[slice1] = self._compute_mult3d(x_ft[slice1], self._weights2) + + slice2 = (slice(None), + slice(None), + slice(-self._modes[0], None), + slice(self._modes[1]), + slice(self._modes[2]), + ) + out_ft[slice2] = self._compute_mult3d(x_ft[slice2], self._weights3) + + slice3 = (slice(None), + slice(None), + slice(-self._modes[0], None), + slice(-self._modes[1], None), + slice(self._modes[2]), + ) + out_ft[slice3] = self._compute_mult3d(x_ft[slice3], self._weights4) + + # Return to physical space + return torch.fft.irfftn(out_ft, s=(x.size(-3), x.size(-2), x.size(-1))) + diff --git a/tests/test_layers/test_spectral_conv.py b/tests/test_layers/test_spectral_conv.py new file mode 100644 index 0000000..db9f399 --- /dev/null +++ b/tests/test_layers/test_spectral_conv.py @@ -0,0 +1,43 @@ +from pina.model.layers import SpectralConvBlock1D, SpectralConvBlock2D, SpectralConvBlock3D +import torch + +input_numb_fields = 3 +output_numb_fields = 4 +batch = 5 + +def test_constructor_1d(): + SpectralConvBlock1D(input_numb_fields=input_numb_fields, + output_numb_fields=output_numb_fields, + n_modes=5) + +def test_forward_1d(): + sconv = SpectralConvBlock1D(input_numb_fields=input_numb_fields, + output_numb_fields=output_numb_fields, + n_modes=4) + x = torch.rand(batch, input_numb_fields, 10) + sconv(x) + + +def test_constructor_2d(): + SpectralConvBlock2D(input_numb_fields=input_numb_fields, + output_numb_fields=output_numb_fields, + n_modes=[5, 4]) + +def test_forward_2d(): + sconv = SpectralConvBlock2D(input_numb_fields=input_numb_fields, + output_numb_fields=output_numb_fields, + n_modes=[5, 4]) + x = torch.rand(batch, input_numb_fields, 10, 10) + sconv(x) + +def test_constructor_3d(): + SpectralConvBlock3D(input_numb_fields=input_numb_fields, + output_numb_fields=output_numb_fields, + n_modes=[5, 4, 4]) + +def test_forward_3d(): + sconv = SpectralConvBlock3D(input_numb_fields=input_numb_fields, + output_numb_fields=output_numb_fields, + n_modes=[5, 4, 4]) + x = torch.rand(batch, input_numb_fields, 10, 10, 10) + sconv(x)