Neural Operator fix and addition

* Building FNO for 1D/2D/3D data
* Fixing bug in trunk/branch net in DeepONEt
* Fixing type check bug in spectral conv
* Adding tests for FNO
* Fixing bug in Fourier Layer (conv1d/2d/3d)
This commit is contained in:
Dario Coscia
2023-09-09 22:09:34 +02:00
committed by Nicola Demo
parent 83ecdb0eab
commit 603f56d264
6 changed files with 315 additions and 33 deletions

View File

@@ -0,0 +1,88 @@
import torch
from pina.model import FNO
output_channels = 5
batch_size = 15
resolution = [30, 40, 50]
lifting_dim = 128
def test_constructor():
input_channels = 3
lifting_net = torch.nn.Linear(input_channels, lifting_dim)
projecting_net = torch.nn.Linear(60, output_channels)
# simple constructor
FNO(lifting_net=lifting_net,
projecting_net=projecting_net,
n_modes=5,
dimensions=3,
inner_size=60,
n_layers=5)
# simple constructor with n_modes list
FNO(lifting_net=lifting_net,
projecting_net=projecting_net,
n_modes=[5, 3, 2],
dimensions=3,
inner_size=60,
n_layers=5)
# simple constructor with n_modes list of list
FNO(lifting_net=lifting_net,
projecting_net=projecting_net,
n_modes=[[5, 3, 2], [5, 3, 2]],
dimensions=3,
inner_size=60,
n_layers=2)
# simple constructor with n_modes list of list
projecting_net = torch.nn.Linear(50, output_channels)
FNO(lifting_net=lifting_net,
projecting_net=projecting_net,
n_modes=5,
dimensions=3,
layers=[50, 50])
def test_1d_forward():
input_channels = 1
input_ = torch.rand(batch_size, resolution[0], input_channels)
lifting_net = torch.nn.Linear(input_channels, lifting_dim)
projecting_net = torch.nn.Linear(60, output_channels)
fno = FNO(lifting_net=lifting_net,
projecting_net=projecting_net,
n_modes=5,
dimensions=1,
inner_size=60,
n_layers=2)
out = fno(input_)
assert out.shape == torch.Size([batch_size, resolution[0], output_channels])
def test_2d_forward():
input_channels = 2
input_ = torch.rand(batch_size, resolution[0], resolution[1], input_channels)
lifting_net = torch.nn.Linear(input_channels, lifting_dim)
projecting_net = torch.nn.Linear(60, output_channels)
fno = FNO(lifting_net=lifting_net,
projecting_net=projecting_net,
n_modes=5,
dimensions=2,
inner_size=60,
n_layers=2)
out = fno(input_)
assert out.shape == torch.Size([batch_size, resolution[0], resolution[1], output_channels])
def test_3d_forward():
input_channels = 3
input_ = torch.rand(batch_size, resolution[0], resolution[1], resolution[2], input_channels)
lifting_net = torch.nn.Linear(input_channels, lifting_dim)
projecting_net = torch.nn.Linear(60, output_channels)
fno = FNO(lifting_net=lifting_net,
projecting_net=projecting_net,
n_modes=5,
dimensions=3,
inner_size=60,
n_layers=2)
out = fno(input_)
assert out.shape == torch.Size([batch_size, resolution[0], resolution[1], resolution[2], output_channels])