* Building FNO for 1D/2D/3D data * Fixing bug in trunk/branch net in DeepONet * Fixing type check bug in spectral conv * Adding tests for FNO * Fixing bug in Fourier Layer (conv1d/2d/3d)
88 lines
2.8 KiB
Python
88 lines
2.8 KiB
Python
import torch
|
|
from pina.model import FNO
|
|
|
|
|
|
# Values shared by every test in this module.
output_channels = 5        # out_features of each projecting net
batch_size = 15            # leading dimension of each random input tensor
resolution = [30, 40, 50]  # sizes of the axes between batch and channels
lifting_dim = 128          # out_features of each lifting net
|
|
|
|
|
|
def test_constructor():
    """Smoke-test the FNO constructor with several argument layouts.

    Nothing is asserted about the built models; the test passes as long as
    none of the constructor calls raises.
    """
    input_channels = 3
    lift = torch.nn.Linear(input_channels, lifting_dim)
    project = torch.nn.Linear(60, output_channels)

    # (n_modes, n_layers) pairs: n_modes given as a plain int, as a list,
    # and as a list of lists (two inner lists, paired with n_layers=2).
    mode_cases = (
        (5, 5),
        ([5, 3, 2], 5),
        ([[5, 3, 2], [5, 3, 2]], 2),
    )
    for modes, depth in mode_cases:
        FNO(
            lifting_net=lift,
            projecting_net=project,
            n_modes=modes,
            dimensions=3,
            inner_size=60,
            n_layers=depth,
        )

    # Constructor driven by an explicit `layers` list instead of
    # inner_size / n_layers; the projecting net takes 50 input features,
    # the same value used for the entries of `layers`.
    project = torch.nn.Linear(50, output_channels)
    FNO(
        lifting_net=lift,
        projecting_net=project,
        n_modes=5,
        dimensions=3,
        layers=[50, 50],
    )
|
|
|
|
def test_1d_forward():
    """Run a 1D FNO forward pass and check the output shape.

    The input has shape (batch_size, resolution[0], input_channels); the
    output must keep the leading axes and end with output_channels.
    """
    input_channels = 1
    grid = resolution[:1]
    sample = torch.rand(batch_size, *grid, input_channels)

    model = FNO(
        lifting_net=torch.nn.Linear(input_channels, lifting_dim),
        projecting_net=torch.nn.Linear(60, output_channels),
        n_modes=5,
        dimensions=1,
        inner_size=60,
        n_layers=2,
    )

    result = model(sample)
    expected_shape = torch.Size([batch_size, *grid, output_channels])
    assert result.shape == expected_shape
|
|
|
|
def test_2d_forward():
    """Run a 2D FNO forward pass and check the output shape.

    The input has shape (batch_size, resolution[0], resolution[1],
    input_channels); the output must keep the leading axes and end with
    output_channels.
    """
    input_channels = 2
    grid = resolution[:2]
    sample = torch.rand(batch_size, *grid, input_channels)

    model = FNO(
        lifting_net=torch.nn.Linear(input_channels, lifting_dim),
        projecting_net=torch.nn.Linear(60, output_channels),
        n_modes=5,
        dimensions=2,
        inner_size=60,
        n_layers=2,
    )

    result = model(sample)
    expected_shape = torch.Size([batch_size, *grid, output_channels])
    assert result.shape == expected_shape
|
|
|
|
def test_3d_forward():
    """Run a 3D FNO forward pass and check the output shape.

    The input has shape (batch_size, resolution[0], resolution[1],
    resolution[2], input_channels); the output must keep the leading axes
    and end with output_channels.
    """
    input_channels = 3
    grid = resolution[:3]
    sample = torch.rand(batch_size, *grid, input_channels)

    model = FNO(
        lifting_net=torch.nn.Linear(input_channels, lifting_dim),
        projecting_net=torch.nn.Linear(60, output_channels),
        n_modes=5,
        dimensions=3,
        inner_size=60,
        n_layers=2,
    )

    result = model(sample)
    expected_shape = torch.Size([batch_size, *grid, output_channels])
    assert result.shape == expected_shape