Simplify Graph class (#459)
* Simplifying Graph class and adjust tests --------- Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
4c3e305b09
commit
ab6ca78d85
@@ -4,28 +4,31 @@ from pina.condition import InputOutputPointsCondition
|
||||
from pina.problem.zoo.supervised_problem import SupervisedProblem
|
||||
from pina.graph import RadiusGraph
|
||||
|
||||
|
||||
def test_constructor():
|
||||
input_ = torch.rand((100,10))
|
||||
output_ = torch.rand((100,10))
|
||||
input_ = torch.rand((100, 10))
|
||||
output_ = torch.rand((100, 10))
|
||||
problem = SupervisedProblem(input_=input_, output_=output_)
|
||||
assert isinstance(problem, AbstractProblem)
|
||||
assert hasattr(problem, "conditions")
|
||||
assert isinstance(problem.conditions, dict)
|
||||
assert list(problem.conditions.keys()) == ['data']
|
||||
assert isinstance(problem.conditions['data'], InputOutputPointsCondition)
|
||||
assert list(problem.conditions.keys()) == ["data"]
|
||||
assert isinstance(problem.conditions["data"], InputOutputPointsCondition)
|
||||
|
||||
|
||||
def test_constructor_graph():
|
||||
x = torch.rand((20,100,10))
|
||||
pos = torch.rand((20,100,2))
|
||||
input_ = RadiusGraph(
|
||||
x=x, pos=pos, r=.2, build_edge_attr=True
|
||||
)
|
||||
output_ = torch.rand((100,10))
|
||||
x = torch.rand((20, 100, 10))
|
||||
pos = torch.rand((20, 100, 2))
|
||||
input_ = [
|
||||
RadiusGraph(x=x_, pos=pos_, radius=0.2, edge_attr=True)
|
||||
for x_, pos_ in zip(x, pos)
|
||||
]
|
||||
output_ = torch.rand((100, 10))
|
||||
problem = SupervisedProblem(input_=input_, output_=output_)
|
||||
assert isinstance(problem, AbstractProblem)
|
||||
assert hasattr(problem, "conditions")
|
||||
assert isinstance(problem.conditions, dict)
|
||||
assert list(problem.conditions.keys()) == ['data']
|
||||
assert isinstance(problem.conditions['data'], InputOutputPointsCondition)
|
||||
assert isinstance(problem.conditions['data'].input_points, list)
|
||||
assert isinstance(problem.conditions['data'].output_points, torch.Tensor)
|
||||
assert list(problem.conditions.keys()) == ["data"]
|
||||
assert isinstance(problem.conditions["data"], InputOutputPointsCondition)
|
||||
assert isinstance(problem.conditions["data"].input_points, list)
|
||||
assert isinstance(problem.conditions["data"].output_points, torch.Tensor)
|
||||
|
||||
Reference in New Issue
Block a user