Improve conditions and refactor dataset classes (#475)
* Reimplement conditions * Refactor datasets and implement LabelBatch --------- Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
bdad144461
commit
a0cbf1c44a
@@ -2,42 +2,151 @@ import torch
|
||||
import pytest
|
||||
|
||||
from pina import LabelTensor, Condition
|
||||
from pina.condition import (
|
||||
TensorInputGraphTargetCondition,
|
||||
TensorInputTensorTargetCondition,
|
||||
GraphInputGraphTargetCondition,
|
||||
GraphInputTensorTargetCondition,
|
||||
)
|
||||
from pina.condition import (
|
||||
InputTensorEquationCondition,
|
||||
InputGraphEquationCondition,
|
||||
DomainEquationCondition,
|
||||
)
|
||||
from pina.condition import (
|
||||
TensorDataCondition,
|
||||
GraphDataCondition,
|
||||
)
|
||||
from pina.domain import CartesianDomain
|
||||
from pina.equation.equation_factory import FixedValue
|
||||
from pina.graph import RadiusGraph
|
||||
|
||||
example_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]})
|
||||
example_input_pts = LabelTensor(torch.tensor([[0, 0, 0]]), ['x', 'y', 'z'])
|
||||
example_output_pts = LabelTensor(torch.tensor([[1, 2]]), ['a', 'b'])
|
||||
example_domain = CartesianDomain({"x": [0, 1], "y": [0, 1]})
|
||||
|
||||
input_tensor = torch.rand((10,3))
|
||||
target_tensor = torch.rand((10,2))
|
||||
input_lt = LabelTensor(torch.rand((10,3)), ["x", "y", "z"])
|
||||
target_lt = LabelTensor(torch.rand((10,2)), ["a", "b"])
|
||||
|
||||
x = torch.rand(10, 20, 2)
|
||||
pos = torch.rand(10, 20, 2)
|
||||
radius = 0.1
|
||||
input_graph = [
|
||||
RadiusGraph(
|
||||
x=x_,
|
||||
pos=pos_,
|
||||
radius=radius,
|
||||
)
|
||||
for x_, pos_ in zip(x, pos)
|
||||
]
|
||||
target_graph = [
|
||||
RadiusGraph(
|
||||
x=x_,
|
||||
pos=pos_,
|
||||
radius=radius,
|
||||
)
|
||||
for x_, pos_ in zip(x, pos)
|
||||
]
|
||||
|
||||
x = LabelTensor(torch.rand(10, 20, 2), ["u", "v"])
|
||||
pos = LabelTensor(torch.rand(10, 20, 2), ["x", "y"])
|
||||
radius = 0.1
|
||||
input_graph_lt = [
|
||||
RadiusGraph(
|
||||
x=x[i],
|
||||
pos=pos[i],
|
||||
radius=radius,
|
||||
)
|
||||
for i in range(len(x))
|
||||
]
|
||||
target_graph_lt = [
|
||||
RadiusGraph(
|
||||
x=x[i],
|
||||
pos=pos[i],
|
||||
radius=radius,
|
||||
)
|
||||
for i in range(len(x))
|
||||
]
|
||||
|
||||
input_single_graph = input_graph[0]
|
||||
target_single_graph = target_graph[0]
|
||||
|
||||
|
||||
def test_init_inputoutput():
|
||||
Condition(input_points=example_input_pts, output_points=example_output_pts)
|
||||
def test_init_input_target():
|
||||
cond = Condition(input=input_tensor, target=target_tensor)
|
||||
assert isinstance(cond, TensorInputTensorTargetCondition)
|
||||
cond = Condition(input=input_tensor, target=target_tensor)
|
||||
assert isinstance(cond, TensorInputTensorTargetCondition)
|
||||
cond = Condition(input=input_tensor, target=target_graph)
|
||||
assert isinstance(cond, TensorInputGraphTargetCondition)
|
||||
cond = Condition(input=input_graph, target=target_tensor)
|
||||
assert isinstance(cond, GraphInputTensorTargetCondition)
|
||||
cond = Condition(input=input_graph, target=target_graph)
|
||||
assert isinstance(cond, GraphInputGraphTargetCondition)
|
||||
|
||||
cond = Condition(input=input_lt, target=input_single_graph)
|
||||
assert isinstance(cond, TensorInputGraphTargetCondition)
|
||||
cond = Condition(input=input_single_graph, target=target_lt)
|
||||
assert isinstance(cond, GraphInputTensorTargetCondition)
|
||||
cond = Condition(input=input_graph, target=target_graph)
|
||||
assert isinstance(cond, GraphInputGraphTargetCondition)
|
||||
cond = Condition(input=input_single_graph, target=target_single_graph)
|
||||
assert isinstance(cond, GraphInputGraphTargetCondition)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
Condition(example_input_pts, example_output_pts)
|
||||
Condition(input_tensor, input_tensor)
|
||||
with pytest.raises(ValueError):
|
||||
Condition(input_points=3., output_points='example')
|
||||
Condition(input=3.0, target="example")
|
||||
with pytest.raises(ValueError):
|
||||
Condition(input_points=example_domain, output_points=example_domain)
|
||||
Condition(input=example_domain, target=example_domain)
|
||||
|
||||
# Test wrong graph condition initialisation
|
||||
input = [input_graph[0], input_graph_lt[0]]
|
||||
target = [target_graph[0], target_graph_lt[0]]
|
||||
with pytest.raises(ValueError):
|
||||
Condition(input=input, target=target)
|
||||
|
||||
input_graph_lt[0].x.labels = ["a", "b"]
|
||||
with pytest.raises(ValueError):
|
||||
Condition(input=input_graph_lt, target=target_graph_lt)
|
||||
input_graph_lt[0].x.labels = ["u", "v"]
|
||||
|
||||
|
||||
test_init_inputoutput()
|
||||
|
||||
|
||||
def test_init_domainfunc():
|
||||
Condition(domain=example_domain, equation=FixedValue(0.0))
|
||||
def test_init_domain_equation():
|
||||
cond = Condition(domain=example_domain, equation=FixedValue(0.0))
|
||||
assert isinstance(cond, DomainEquationCondition)
|
||||
with pytest.raises(ValueError):
|
||||
Condition(example_domain, FixedValue(0.0))
|
||||
with pytest.raises(ValueError):
|
||||
Condition(domain=3., equation='example')
|
||||
Condition(domain=3.0, equation="example")
|
||||
with pytest.raises(ValueError):
|
||||
Condition(domain=example_input_pts, equation=example_output_pts)
|
||||
Condition(domain=input_tensor, equation=input_graph)
|
||||
|
||||
|
||||
def test_init_inputfunc():
|
||||
Condition(input_points=example_input_pts, equation=FixedValue(0.0))
|
||||
def test_init_input_equation():
|
||||
cond = Condition(input=input_lt, equation=FixedValue(0.0))
|
||||
assert isinstance(cond, InputTensorEquationCondition)
|
||||
cond = Condition(input=input_graph_lt, equation=FixedValue(0.0))
|
||||
assert isinstance(cond, InputGraphEquationCondition)
|
||||
with pytest.raises(ValueError):
|
||||
cond = Condition(input=input_tensor, equation=FixedValue(0.0))
|
||||
with pytest.raises(ValueError):
|
||||
Condition(example_domain, FixedValue(0.0))
|
||||
with pytest.raises(ValueError):
|
||||
Condition(input_points=3., equation='example')
|
||||
Condition(input=3.0, equation="example")
|
||||
with pytest.raises(ValueError):
|
||||
Condition(input_points=example_domain, equation=example_output_pts)
|
||||
Condition(input=example_domain, equation=input_graph)
|
||||
test_init_input_equation()
|
||||
|
||||
def test_init_data_condition():
|
||||
cond = Condition(input=input_lt)
|
||||
assert isinstance(cond, TensorDataCondition)
|
||||
cond = Condition(input=input_tensor)
|
||||
assert isinstance(cond, TensorDataCondition)
|
||||
cond = Condition(input=input_tensor, conditional_variables=torch.tensor(1))
|
||||
assert isinstance(cond, TensorDataCondition)
|
||||
cond = Condition(input=input_graph)
|
||||
assert isinstance(cond, GraphDataCondition)
|
||||
cond = Condition(input=input_graph, conditional_variables=torch.tensor(1))
|
||||
assert isinstance(cond, GraphDataCondition)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user