New Residual Model and Fix relative import
* Adding Residual MLP * Adding test Residual MLP * Modified relative import Continuous Conv
This commit is contained in:
committed by
Nicola Demo
parent
ba7371f350
commit
17464ceca9
22
tests/test_model/test_residualfnn.py
Normal file
22
tests/test_model/test_residualfnn.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import torch
|
||||
import pytest
|
||||
from pina.model import ResidualFeedForward
|
||||
|
||||
def test_constructor():
|
||||
# simple constructor
|
||||
ResidualFeedForward(input_dimensions=2, output_dimensions=1)
|
||||
|
||||
# wrong transformer nets (not 2)
|
||||
with pytest.raises(ValueError):
|
||||
ResidualFeedForward(input_dimensions=2, output_dimensions=1, transformer_nets=[torch.nn.Linear(2, 20)])
|
||||
|
||||
# wrong transformer nets (not nn.Module)
|
||||
with pytest.raises(ValueError):
|
||||
ResidualFeedForward(input_dimensions=2, output_dimensions=1, transformer_nets=[2, 2])
|
||||
|
||||
def test_forward():
|
||||
x = torch.rand(10, 2)
|
||||
model = ResidualFeedForward(input_dimensions=2, output_dimensions=1)
|
||||
model(x)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user