Files
PINA/tests/test_utils.py
Dario Coscia 7bf7d34d0f Dev Update (#582)
* Fix adaptive refinement (#571)


---------

Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com>

* Remove collector

* Fixes

* Fixes

* rm unnecessary comment

* fix advection (#581)

* Fix tutorial .html link (#580)

* fix problem data collection for v0.1 (#584)

* Message Passing Module (#516)

* add deep tensor network block

* add interaction network block

* add radial field network block

* add schnet block

* add equivariant network block

* fix + tests + doc files

* fix egnn + equivariance/invariance tests

Co-authored-by: Dario Coscia <dariocos99@gmail.com>

---------

Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it>

* add type checker (#527)

---------

Co-authored-by: Filippo Olivo <filippo@filippoolivo.com>
Co-authored-by: Giovanni Canali <115086358+GiovanniCanali@users.noreply.github.com>
Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it>
2025-06-13 17:34:37 +02:00

71 lines
2.3 KiB
Python

import torch
import pytest
from pina import LabelTensor
from pina.utils import merge_tensors, check_consistency, check_positive_integer
from pina.domain import EllipsoidDomain, CartesianDomain, DomainInterface
def test_merge_tensors():
tensor1 = LabelTensor(torch.rand((20, 3)), ["a", "b", "c"])
tensor2 = LabelTensor(torch.zeros((20, 3)), ["d", "e", "f"])
tensor3 = LabelTensor(torch.ones((30, 3)), ["g", "h", "i"])
merged_tensor = merge_tensors((tensor1, tensor2, tensor3))
assert tuple(merged_tensor.labels) == (
"a",
"b",
"c",
"d",
"e",
"f",
"g",
"h",
"i",
)
assert merged_tensor.shape == (20 * 20 * 30, 9)
assert torch.all(merged_tensor.extract(("d", "e", "f")) == 0)
assert torch.all(merged_tensor.extract(("g", "h", "i")) == 1)
def test_check_consistency_correct():
ellipsoid1 = EllipsoidDomain({"x": [1, 2], "y": [-2, 1]})
example_input_pts = LabelTensor(torch.tensor([[0, 0, 0]]), ["x", "y", "z"])
check_consistency(example_input_pts, torch.Tensor)
check_consistency(CartesianDomain, DomainInterface, subclass=True)
check_consistency(ellipsoid1, DomainInterface)
def test_check_consistency_incorrect():
ellipsoid1 = EllipsoidDomain({"x": [1, 2], "y": [-2, 1]})
example_input_pts = LabelTensor(torch.tensor([[0, 0, 0]]), ["x", "y", "z"])
with pytest.raises(ValueError):
check_consistency(example_input_pts, DomainInterface)
with pytest.raises(ValueError):
check_consistency(torch.Tensor, DomainInterface, subclass=True)
with pytest.raises(ValueError):
check_consistency(ellipsoid1, torch.Tensor)
@pytest.mark.parametrize("value", [0, 1, 2, 3, 10])
@pytest.mark.parametrize("strict", [True, False])
def test_check_positive_integer(value, strict):
if value != 0:
check_positive_integer(value, strict=strict)
else:
check_positive_integer(value, strict=False)
# Should fail if value is negative
with pytest.raises(AssertionError):
check_positive_integer(-1, strict=strict)
# Should fail if value is not an integer
with pytest.raises(AssertionError):
check_positive_integer(1.5, strict=strict)
# Should fail if value is not a number
with pytest.raises(AssertionError):
check_positive_integer("string", strict=strict)