use LabelTensor, fix minor, docs

This commit is contained in:
Your Name
2022-03-29 18:05:26 +02:00
parent 12f4084d7f
commit 6b001c6c53
19 changed files with 370 additions and 322 deletions

26
tests/test_deeponet.py Normal file
View File

@@ -0,0 +1,26 @@
import torch
import pytest
from pina import LabelTensor
from pina.model import DeepONet, FeedForward as FFN
data = torch.rand((20, 3))
input_vars = ['a', 'b', 'c']
output_vars = ['d']
input_ = LabelTensor(data, input_vars)
def test_constructor():
branch = FFN(input_variables=['a', 'c'], output_variables=20)
trunk = FFN(input_variables=['b'], output_variables=20)
onet = DeepONet(trunk_net=trunk, branch_net=branch,
output_variables=output_vars)
def test_forward():
branch = FFN(input_variables=['a', 'c'], output_variables=10)
trunk = FFN(input_variables=['b'], output_variables=10)
onet = DeepONet(trunk_net=trunk, branch_net=branch,
output_variables=output_vars)
output_ = onet(input_)
assert output_.labels == output_vars

58
tests/test_fnn.py Normal file
View File

@@ -0,0 +1,58 @@
import torch
import pytest
from pina import LabelTensor
from pina.model import FeedForward
class myFeature(torch.nn.Module):
"""
Feature: sin(pi*x)
"""
def __init__(self):
super(myFeature, self).__init__()
def forward(self, x):
return torch.sin(torch.pi * x.extract('a'))
data = torch.rand((20, 3))
input_vars = ['a', 'b', 'c']
output_vars = ['d', 'e']
input_ = LabelTensor(data, input_vars)
def test_constructor():
FeedForward(input_vars, output_vars)
FeedForward(3, 4)
FeedForward(input_vars, output_vars, extra_features=[myFeature()])
FeedForward(input_vars, output_vars, inner_size=10, n_layers=20)
FeedForward(input_vars, output_vars, layers=[10, 20, 5, 2])
FeedForward(input_vars, output_vars, layers=[10, 20, 5, 2],
func=torch.nn.ReLU)
FeedForward(input_vars, output_vars, layers=[10, 20, 5, 2],
func=[torch.nn.ReLU, torch.nn.ReLU, None, torch.nn.Tanh])
def test_constructor_wrong():
with pytest.raises(RuntimeError):
FeedForward(input_vars, output_vars, layers=[10, 20, 5, 2],
func=[torch.nn.ReLU, torch.nn.ReLU])
def test_forward():
fnn = FeedForward(input_vars, output_vars)
output_ = fnn(input_)
assert output_.labels == output_vars
def test_forward2():
dim_in, dim_out = 3, 2
fnn = FeedForward(dim_in, dim_out)
output_ = fnn(input_)
assert output_.shape == (input_.shape[0], dim_out)
def test_forward_features():
fnn = FeedForward(input_vars, output_vars, extra_features=[myFeature()])
output_ = fnn(input_)
assert output_.labels == output_vars

View File

@@ -59,3 +59,22 @@ def test_extract_order():
assert new.labels == label_to_extract
assert new.shape[1] == len(label_to_extract)
assert torch.all(torch.isclose(expected, new))
def test_merge():
tensor = LabelTensor(data, labels)
tensor_a = tensor.extract('a')
tensor_b = tensor.extract('b')
tensor_c = tensor.extract('c')
tensor_bc = tensor_b.append(tensor_c)
assert torch.allclose(tensor_bc, tensor.extract(['b', 'c']))
def test_merge():
tensor = LabelTensor(data, labels)
tensor_a = tensor.extract('a')
tensor_b = tensor.extract('b')
tensor_c = tensor.extract('c')
tensor_bb = tensor_b.append(tensor_b)
assert torch.allclose(tensor_b, tensor.extract(['b', 'c']))