remove back compatibility files for version 0.2
This commit is contained in:
committed by
Giovanni Canali
parent
ef3542486c
commit
684d691b78
53
tests/test_block/test_pirate_network_block.py
Normal file
53
tests/test_block/test_pirate_network_block.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import torch
|
||||
import pytest
|
||||
from pina.model.block import PirateNetBlock
|
||||
|
||||
data = torch.rand((20, 3))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("inner_size", [10, 20])
|
||||
def test_constructor(inner_size):
|
||||
|
||||
PirateNetBlock(inner_size=inner_size, activation=torch.nn.Tanh)
|
||||
|
||||
# Should fail if inner_size is negative
|
||||
with pytest.raises(AssertionError):
|
||||
PirateNetBlock(inner_size=-1, activation=torch.nn.Tanh)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("inner_size", [10, 20])
|
||||
def test_forward(inner_size):
|
||||
|
||||
model = PirateNetBlock(inner_size=inner_size, activation=torch.nn.Tanh)
|
||||
|
||||
# Create dummy embedding
|
||||
dummy_embedding = torch.nn.Linear(data.shape[1], inner_size)
|
||||
x = dummy_embedding(data)
|
||||
|
||||
# Create dummy U and V tensors
|
||||
U = torch.rand((data.shape[0], inner_size))
|
||||
V = torch.rand((data.shape[0], inner_size))
|
||||
|
||||
output_ = model(x, U, V)
|
||||
assert output_.shape == (data.shape[0], inner_size)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("inner_size", [10, 20])
|
||||
def test_backward(inner_size):
|
||||
|
||||
model = PirateNetBlock(inner_size=inner_size, activation=torch.nn.Tanh)
|
||||
data.requires_grad_()
|
||||
|
||||
# Create dummy embedding
|
||||
dummy_embedding = torch.nn.Linear(data.shape[1], inner_size)
|
||||
x = dummy_embedding(data)
|
||||
|
||||
# Create dummy U and V tensors
|
||||
U = torch.rand((data.shape[0], inner_size))
|
||||
V = torch.rand((data.shape[0], inner_size))
|
||||
|
||||
output_ = model(x, U, V)
|
||||
|
||||
loss = torch.mean(output_)
|
||||
loss.backward()
|
||||
assert data.grad.shape == data.shape
|
||||
Reference in New Issue
Block a user