This commit is contained in:
gc031298
2025-02-21 10:30:17 +01:00
committed by Nicola Demo
parent ff43a7492b
commit ed0a8bd5e7
15 changed files with 5 additions and 5 deletions

View File

@@ -0,0 +1,58 @@
import torch
import pytest
from pina.model.block import LowRankBlock
from pina import LabelTensor
input_dimensions=2
embedding_dimenion=1
rank=4
inner_size=20
n_layers=2
func=torch.nn.Tanh
bias=True
def test_constructor():
LowRankBlock(input_dimensions=input_dimensions,
embedding_dimenion=embedding_dimenion,
rank=rank,
inner_size=inner_size,
n_layers=n_layers,
func=func,
bias=bias)
def test_constructor_wrong():
with pytest.raises(ValueError):
LowRankBlock(input_dimensions=input_dimensions,
embedding_dimenion=embedding_dimenion,
rank=0.5,
inner_size=inner_size,
n_layers=n_layers,
func=func,
bias=bias)
def test_forward():
block = LowRankBlock(input_dimensions=input_dimensions,
embedding_dimenion=embedding_dimenion,
rank=rank,
inner_size=inner_size,
n_layers=n_layers,
func=func,
bias=bias)
data = LabelTensor(torch.rand(10, 30, 3), labels=['x', 'y', 'u'])
block(data.extract('u'), data.extract(['x', 'y']))
def test_backward():
block = LowRankBlock(input_dimensions=input_dimensions,
embedding_dimenion=embedding_dimenion,
rank=rank,
inner_size=inner_size,
n_layers=n_layers,
func=func,
bias=bias)
data = LabelTensor(torch.rand(10, 30, 3), labels=['x', 'y', 'u'])
data.requires_grad_(True)
out = block(data.extract('u'), data.extract(['x', 'y']))
loss = out.mean()
loss.backward()