Files
PINA/tests/test_model/test_fno.py
Dario Coscia 9cae9a438f Update solvers (#434)
* Enable DDP training with batch_size=None and add validity check for split sizes
* Refactoring SolverInterfaces (#435)
* Solver update + weighting
* Updating PINN for 0.2
* Modify GAROM + tests
* Adding more versatile loggers
* Disable compilation when running on Windows
* Fix tests

---------

Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: FilippoOlivo <filippo@filippoolivo.com>
2025-03-19 17:46:35 +01:00

156 lines
4.8 KiB
Python

import torch
from pina.model import FNO
output_channels = 5
batch_size = 4
resolution = [4, 6, 8]
lifting_dim = 24
def test_constructor():
input_channels = 3
lifting_net = torch.nn.Linear(input_channels, lifting_dim)
projecting_net = torch.nn.Linear(60, output_channels)
# simple constructor
FNO(lifting_net=lifting_net,
projecting_net=projecting_net,
n_modes=5,
dimensions=3,
inner_size=60,
n_layers=5)
# simple constructor with n_modes list
FNO(lifting_net=lifting_net,
projecting_net=projecting_net,
n_modes=[5, 3, 2],
dimensions=3,
inner_size=60,
n_layers=5)
# simple constructor with n_modes list of list
FNO(lifting_net=lifting_net,
projecting_net=projecting_net,
n_modes=[[5, 3, 2], [5, 3, 2]],
dimensions=3,
inner_size=60,
n_layers=2)
# simple constructor with n_modes list of list
projecting_net = torch.nn.Linear(50, output_channels)
FNO(lifting_net=lifting_net,
projecting_net=projecting_net,
n_modes=5,
dimensions=3,
layers=[50, 50])
def test_1d_forward():
input_channels = 1
input_ = torch.rand(batch_size, resolution[0], input_channels)
lifting_net = torch.nn.Linear(input_channels, lifting_dim)
projecting_net = torch.nn.Linear(60, output_channels)
fno = FNO(lifting_net=lifting_net,
projecting_net=projecting_net,
n_modes=5,
dimensions=1,
inner_size=60,
n_layers=2)
out = fno(input_)
assert out.shape == torch.Size([batch_size, resolution[0], output_channels])
def test_1d_backward():
input_channels = 1
input_ = torch.rand(batch_size, resolution[0], input_channels)
lifting_net = torch.nn.Linear(input_channels, lifting_dim)
projecting_net = torch.nn.Linear(60, output_channels)
fno = FNO(lifting_net=lifting_net,
projecting_net=projecting_net,
n_modes=5,
dimensions=1,
inner_size=60,
n_layers=2)
input_.requires_grad = True
out = fno(input_)
l = torch.mean(out)
l.backward()
assert input_.grad.shape == torch.Size([batch_size, resolution[0], input_channels])
def test_2d_forward():
input_channels = 2
input_ = torch.rand(batch_size, resolution[0], resolution[1],
input_channels)
lifting_net = torch.nn.Linear(input_channels, lifting_dim)
projecting_net = torch.nn.Linear(60, output_channels)
fno = FNO(lifting_net=lifting_net,
projecting_net=projecting_net,
n_modes=5,
dimensions=2,
inner_size=60,
n_layers=2)
out = fno(input_)
assert out.shape == torch.Size(
[batch_size, resolution[0], resolution[1], output_channels])
def test_2d_backward():
input_channels = 2
input_ = torch.rand(batch_size, resolution[0], resolution[1],
input_channels)
lifting_net = torch.nn.Linear(input_channels, lifting_dim)
projecting_net = torch.nn.Linear(60, output_channels)
fno = FNO(lifting_net=lifting_net,
projecting_net=projecting_net,
n_modes=5,
dimensions=2,
inner_size=60,
n_layers=2)
input_.requires_grad = True
out = fno(input_)
l = torch.mean(out)
l.backward()
assert input_.grad.shape == torch.Size([
batch_size, resolution[0], resolution[1], input_channels
])
def test_3d_forward():
input_channels = 3
input_ = torch.rand(batch_size, resolution[0], resolution[1], resolution[2],
input_channels)
lifting_net = torch.nn.Linear(input_channels, lifting_dim)
projecting_net = torch.nn.Linear(60, output_channels)
fno = FNO(lifting_net=lifting_net,
projecting_net=projecting_net,
n_modes=5,
dimensions=3,
inner_size=60,
n_layers=2)
out = fno(input_)
assert out.shape == torch.Size([
batch_size, resolution[0], resolution[1], resolution[2], output_channels
])
def test_3d_backward():
input_channels = 3
input_ = torch.rand(batch_size, resolution[0], resolution[1], resolution[2],
input_channels)
lifting_net = torch.nn.Linear(input_channels, lifting_dim)
projecting_net = torch.nn.Linear(60, output_channels)
fno = FNO(lifting_net=lifting_net,
projecting_net=projecting_net,
n_modes=5,
dimensions=3,
inner_size=60,
n_layers=2)
input_.requires_grad = True
out = fno(input_)
l = torch.mean(out)
l.backward()
assert input_.grad.shape == torch.Size([
batch_size, resolution[0], resolution[1], resolution[2], input_channels
])