38 lines
1.0 KiB
Python
38 lines
1.0 KiB
Python
import torch
|
|
import pytest
|
|
|
|
from pina.model.network import Network
|
|
from pina.model import FeedForward
|
|
from pina import LabelTensor
|
|
|
|
# Shared test fixtures.
# Dimensions of the model: 3 input variables ('x', 'y', 'z') and 4 outputs.
input_dim = 3
output_dim = 4

# A (20, 3) random tensor, and the same values wrapped with input labels.
# NOTE: drawn before the model is built so RNG consumption order is unchanged.
data = torch.rand((20, input_dim))
data_lt = LabelTensor(data, ['x', 'y', 'z'])

# Plain FeedForward model that the tests wrap in a Network.
torchmodel = FeedForward(input_dim, output_dim)

# Extra-features placeholder; the tests in this file pass extra_features=None.
extra_feat = []

def test_constructor():
    """Wrapping the shared FeedForward in a Network must not raise."""
    input_variables = ['x', 'y', 'z']
    output_variables = ['a', 'b', 'c', 'd']
    Network(
        model=torchmodel,
        input_variables=input_variables,
        output_variables=output_variables,
        extra_features=None,
    )

def test_forward():
    """Compare Network's forward pass on labelled vs. plain input.

    ``net.torchmodel`` applied to the plain tensor and ``net`` applied to the
    LabelTensor must agree in type, shape (20, 4) and values, and the
    labelled result must carry the declared output labels. Calling ``net``
    with the unlabelled tensor is expected to raise AssertionError.
    """
    net = Network(
        model=torchmodel,
        input_variables=['x', 'y', 'z'],
        output_variables=['a', 'b', 'c', 'd'],
        extra_features=None,
    )

    raw_out = net.torchmodel(data)
    labelled_out = net(data_lt)

    expected_shape = (20, 4)
    assert isinstance(raw_out, torch.Tensor)
    assert isinstance(labelled_out, LabelTensor)
    assert raw_out.shape == expected_shape
    assert labelled_out.shape == expected_shape
    assert torch.allclose(labelled_out, raw_out)
    assert labelled_out.labels == ['a', 'b', 'c', 'd']

    # The plain (unlabelled) tensor must be rejected by the Network.
    with pytest.raises(AssertionError):
        net(data)