remove back compatibility files for version 0.2

This commit is contained in:
giovanni
2025-09-05 11:12:50 +02:00
committed by Giovanni Canali
parent ef3542486c
commit 684d691b78
27 changed files with 0 additions and 120 deletions

View File

@@ -0,0 +1,53 @@
import torch
import pytest
from pina.model.block import PirateNetBlock
data = torch.rand((20, 3))
@pytest.mark.parametrize("inner_size", [10, 20])
def test_constructor(inner_size):
PirateNetBlock(inner_size=inner_size, activation=torch.nn.Tanh)
# Should fail if inner_size is negative
with pytest.raises(AssertionError):
PirateNetBlock(inner_size=-1, activation=torch.nn.Tanh)
@pytest.mark.parametrize("inner_size", [10, 20])
def test_forward(inner_size):
model = PirateNetBlock(inner_size=inner_size, activation=torch.nn.Tanh)
# Create dummy embedding
dummy_embedding = torch.nn.Linear(data.shape[1], inner_size)
x = dummy_embedding(data)
# Create dummy U and V tensors
U = torch.rand((data.shape[0], inner_size))
V = torch.rand((data.shape[0], inner_size))
output_ = model(x, U, V)
assert output_.shape == (data.shape[0], inner_size)
@pytest.mark.parametrize("inner_size", [10, 20])
def test_backward(inner_size):
model = PirateNetBlock(inner_size=inner_size, activation=torch.nn.Tanh)
data.requires_grad_()
# Create dummy embedding
dummy_embedding = torch.nn.Linear(data.shape[1], inner_size)
x = dummy_embedding(data)
# Create dummy U and V tensors
U = torch.rand((data.shape[0], inner_size))
V = torch.rand((data.shape[0], inner_size))
output_ = model(x, U, V)
loss = torch.mean(output_)
loss.backward()
assert data.grad.shape == data.shape