Layers and Models update PR

* add residual block
* add test conv and residual block
* modify FFN kwargs
This commit is contained in:
Dario Coscia
2023-08-01 19:13:36 +02:00
committed by Nicola Demo
parent 8c16e27ae4
commit 15ecaacb7c
9 changed files with 191 additions and 33 deletions

View File

@@ -1,4 +1,4 @@
from pina.model.layers import ContinuousConv
from pina.model.layers import ContinuousConvBlock
import torch
@@ -70,12 +70,12 @@ x = make_grid(x)
def test_constructor():
model = MLP
conv = ContinuousConv(channel_input,
conv = ContinuousConvBlock(channel_input,
channel_output,
dim,
stride,
model=model)
conv = ContinuousConv(channel_input,
conv = ContinuousConvBlock(channel_input,
channel_output,
dim,
stride,
@@ -86,7 +86,7 @@ def test_forward():
model = MLP
# simple forward
conv = ContinuousConv(channel_input,
conv = ContinuousConvBlock(channel_input,
channel_output,
dim,
stride,
@@ -94,7 +94,7 @@ def test_forward():
conv(x)
# simple forward with optimization
conv = ContinuousConv(channel_input,
conv = ContinuousConvBlock(channel_input,
channel_output,
dim,
stride,
@@ -107,13 +107,13 @@ def test_transpose():
model = MLP
# simple transpose
conv = ContinuousConv(channel_input,
conv = ContinuousConvBlock(channel_input,
channel_output,
dim,
stride,
model=model)
conv2 = ContinuousConv(channel_output,
conv2 = ContinuousConvBlock(channel_output,
channel_input,
dim,
stride,
@@ -122,13 +122,13 @@ def test_transpose():
integrals = conv(x)
conv2.transpose(integrals[..., -1], x)
stride_no_overlap = {"domain": [10, 10],
"start": [0, 0],
"jumps": dim,
"direction": [1, 1.]}
# stride_no_overlap = {"domain": [10, 10],
# "start": [0, 0],
# "jumps": dim,
# "direction": [1, 1.]}
# simple transpose with optimization
# conv = ContinuousConv(channel_input,
## simple transpose with optimization
# conv = ContinuousConvBlock(channel_input,
# channel_output,
# dim,
# stride_no_overlap,
@@ -137,4 +137,4 @@ def test_transpose():
# no_overlap=True)
# integrals = conv(x)
# conv.transpose(integrals[..., -1], x)
# conv.transpose(integrals[..., -1], x)

View File

@@ -0,0 +1,26 @@
from pina.model.layers import ResidualBlock
import torch
def test_constructor():
res_block = ResidualBlock(input_dim=10,
output_dim=3,
hidden_dim=4)
res_block = ResidualBlock(input_dim=10,
output_dim=3,
hidden_dim=4,
spectral_norm=True)
def test_forward():
res_block = ResidualBlock(input_dim=10,
output_dim=3,
hidden_dim=4)
x = torch.rand(size=(80, 10))
y = res_block(x)
assert y.shape[1]==3
assert y.shape[0]==x.shape[0]