Dev Update (#582)
* Fix adaptive refinement (#571) --------- Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com> * Remove collector * Fixes * Fixes * rm unnecessary comment * fix advection (#581) * Fix tutorial .html link (#580) * fix problem data collection for v0.1 (#584) * Message Passing Module (#516) * add deep tensor network block * add interaction network block * add radial field network block * add schnet block * add equivariant network block * fix + tests + doc files * fix egnn + equivariance/invariance tests Co-authored-by: Dario Coscia <dariocos99@gmail.com> --------- Co-authored-by: giovanni <giovanni.canali98@yahoo.it> Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it> * add type checker (#527) --------- Co-authored-by: Filippo Olivo <filippo@filippoolivo.com> Co-authored-by: Giovanni Canali <115086358+GiovanniCanali@users.noreply.github.com> Co-authored-by: giovanni <giovanni.canali98@yahoo.it> Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it>
This commit is contained in:
@@ -4,6 +4,11 @@ from pina.problem.zoo import Poisson2DSquareProblem as Poisson
|
||||
from pina import LabelTensor
|
||||
from pina.domain import Union
|
||||
from pina.domain import CartesianDomain
|
||||
from pina.condition import (
|
||||
Condition,
|
||||
InputTargetCondition,
|
||||
DomainEquationCondition,
|
||||
)
|
||||
|
||||
|
||||
def test_discretise_domain():
|
||||
@@ -45,6 +50,24 @@ def test_variables_correct_order_sampling():
|
||||
)
|
||||
|
||||
|
||||
def test_input_pts():
|
||||
n = 10
|
||||
poisson_problem = Poisson()
|
||||
poisson_problem.discretise_domain(n, "grid")
|
||||
assert sorted(list(poisson_problem.input_pts.keys())) == sorted(
|
||||
list(poisson_problem.conditions.keys())
|
||||
)
|
||||
|
||||
|
||||
def test_collected_data():
|
||||
n = 10
|
||||
poisson_problem = Poisson()
|
||||
poisson_problem.discretise_domain(n, "grid")
|
||||
assert sorted(list(poisson_problem.collected_data.keys())) == sorted(
|
||||
list(poisson_problem.conditions.keys())
|
||||
)
|
||||
|
||||
|
||||
def test_add_points():
|
||||
poisson_problem = Poisson()
|
||||
poisson_problem.discretise_domain(0, "random", domains=["D"])
|
||||
@@ -84,3 +107,23 @@ def test_wrong_custom_sampling_logic(mode):
|
||||
}
|
||||
with pytest.raises(RuntimeError):
|
||||
poisson_problem.discretise_domain(sample_rules=sampling_rules)
|
||||
|
||||
|
||||
def test_aggregate_data():
|
||||
poisson_problem = Poisson()
|
||||
poisson_problem.conditions["data"] = Condition(
|
||||
input=LabelTensor(torch.tensor([[0.0, 1.0]]), labels=["x", "y"]),
|
||||
target=LabelTensor(torch.tensor([[0.0]]), labels=["u"]),
|
||||
)
|
||||
poisson_problem.discretise_domain(0, "random", domains="all")
|
||||
poisson_problem.collect_data()
|
||||
assert isinstance(poisson_problem.collected_data, dict)
|
||||
for name, conditions in poisson_problem.conditions.items():
|
||||
assert name in poisson_problem.collected_data.keys()
|
||||
if isinstance(conditions, InputTargetCondition):
|
||||
assert "input" in poisson_problem.collected_data[name].keys()
|
||||
assert "target" in poisson_problem.collected_data[name].keys()
|
||||
elif isinstance(conditions, DomainEquationCondition):
|
||||
assert "input" in poisson_problem.collected_data[name].keys()
|
||||
assert "target" not in poisson_problem.collected_data[name].keys()
|
||||
assert "equation" in poisson_problem.collected_data[name].keys()
|
||||
|
||||
Reference in New Issue
Block a user