fix test PeriodicBoundaryEmbedding (#257)

* fix test PeriodicBoundaryEmbedding
* fix tests
This commit is contained in:
Dario Coscia
2024-03-04 11:50:27 +01:00
committed by GitHub
parent 46b366a461
commit 15136e13f8

View File

@@ -4,11 +4,15 @@ import pytest
from pina.model.layers import PeriodicBoundaryEmbedding from pina.model.layers import PeriodicBoundaryEmbedding
from pina import LabelTensor from pina import LabelTensor
# test tolerance
tol = 1e-6
def check_same_columns(tensor): def check_same_columns(tensor):
# Get the first column # Get the first column and compute residual
first_column = tensor[0] residual = tensor - tensor[0]
zeros = torch.zeros_like(residual)
# Compare each column with the first column # Compare each column with the first column
all_same = torch.allclose(tensor, first_column) all_same = torch.allclose(input=residual,other=zeros,atol=tol)
return all_same return all_same
def grad(u, x): def grad(u, x):
@@ -57,43 +61,43 @@ def test_forward_same_period(input_dimension, period):
def test_forward_same_period_labels(): # def test_forward_same_period_labels():
func = torch.nn.Sequential( # func = torch.nn.Sequential(
PeriodicBoundaryEmbedding(input_dimension=2, # PeriodicBoundaryEmbedding(input_dimension=2,
output_dimension=60, periods={'x':1, 'y':2}), # output_dimension=60, periods={'x':1, 'y':2}),
torch.nn.Tanh(), # torch.nn.Tanh(),
torch.nn.Linear(60, 60), # torch.nn.Linear(60, 60),
torch.nn.Tanh(), # torch.nn.Tanh(),
torch.nn.Linear(60, 1) # torch.nn.Linear(60, 1)
) # )
# coordinates # # coordinates
tensor = torch.tensor([[0., 0.], [0., 2.], [1., 0.], [1., 2.]]) # tensor = torch.tensor([[0., 0.], [0., 2.], [1., 0.], [1., 2.]])
with pytest.raises(RuntimeError): # with pytest.raises(RuntimeError):
func(tensor) # func(tensor)
tensor = tensor.as_subclass(LabelTensor) # tensor = tensor.as_subclass(LabelTensor)
tensor.labels = ['x', 'y'] # tensor.labels = ['x', 'y']
tensor.requires_grad = True # tensor.requires_grad = True
# output # # output
f = func(tensor) # f = func(tensor)
assert check_same_columns(f) # assert check_same_columns(f)
def test_forward_same_period_index(): # def test_forward_same_period_index():
func = torch.nn.Sequential( # func = torch.nn.Sequential(
PeriodicBoundaryEmbedding(input_dimension=2, # PeriodicBoundaryEmbedding(input_dimension=2,
output_dimension=60, periods={0:1, 1:2}), # output_dimension=60, periods={0:1, 1:2}),
torch.nn.Tanh(), # torch.nn.Tanh(),
torch.nn.Linear(60, 60), # torch.nn.Linear(60, 60),
torch.nn.Tanh(), # torch.nn.Tanh(),
torch.nn.Linear(60, 1) # torch.nn.Linear(60, 1)
) # )
# coordinates # # coordinates
tensor = torch.tensor([[0., 0.], [0., 2.], [1., 0.], [1., 2.]]) # tensor = torch.tensor([[0., 0.], [0., 2.], [1., 0.], [1., 2.]])
tensor.requires_grad = True # tensor.requires_grad = True
# output # # output
f = func(tensor) # f = func(tensor)
assert check_same_columns(f) # assert check_same_columns(f)
tensor = tensor.as_subclass(LabelTensor) # tensor = tensor.as_subclass(LabelTensor)
tensor.labels = ['x', 'y'] # tensor.labels = ['x', 'y']
# output # # output
f = func(tensor) # f = func(tensor)
assert check_same_columns(f) # assert check_same_columns(f)