Add test supervised solver for graph based models (#480)
This commit is contained in:
committed by
Nicola Demo
parent
4177bfbb50
commit
2ae4a94e49
@@ -1,12 +1,14 @@
|
|||||||
import torch
|
import torch
|
||||||
import pytest
|
import pytest
|
||||||
|
from torch._dynamo.eval_frame import OptimizedModule
|
||||||
|
from torch_geometric.nn import GCNConv
|
||||||
from pina import Condition, LabelTensor
|
from pina import Condition, LabelTensor
|
||||||
from pina.condition import InputTargetCondition
|
from pina.condition import InputTargetCondition
|
||||||
from pina.problem import AbstractProblem
|
from pina.problem import AbstractProblem
|
||||||
from pina.solver import SupervisedSolver
|
from pina.solver import SupervisedSolver
|
||||||
from pina.model import FeedForward
|
from pina.model import FeedForward
|
||||||
from pina.trainer import Trainer
|
from pina.trainer import Trainer
|
||||||
from torch._dynamo.eval_frame import OptimizedModule
|
from pina.graph import KNNGraph
|
||||||
|
|
||||||
|
|
||||||
class LabelTensorProblem(AbstractProblem):
|
class LabelTensorProblem(AbstractProblem):
|
||||||
@@ -28,9 +30,64 @@ class TensorProblem(AbstractProblem):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
x = torch.rand((100, 20, 5))
|
||||||
|
pos = torch.rand((100, 20, 2))
|
||||||
|
output_ = torch.rand((100, 20, 1))
|
||||||
|
input_ = [
|
||||||
|
KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True)
|
||||||
|
for x_, pos_ in zip(x, pos)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class GraphProblem(AbstractProblem):
|
||||||
|
output_variables = None
|
||||||
|
conditions = {"data": Condition(input=input_, target=output_)}
|
||||||
|
|
||||||
|
|
||||||
|
x = LabelTensor(torch.rand((100, 20, 5)), ["a", "b", "c", "d", "e"])
|
||||||
|
pos = LabelTensor(torch.rand((100, 20, 2)), ["x", "y"])
|
||||||
|
output_ = LabelTensor(torch.rand((100, 20, 1)), ["u"])
|
||||||
|
input_ = [
|
||||||
|
KNNGraph(x=x[i], pos=pos[i], neighbours=3, edge_attr=True)
|
||||||
|
for i in range(len(x))
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class GraphProblemLT(AbstractProblem):
|
||||||
|
output_variables = ["u"]
|
||||||
|
input_variables = ["a", "b", "c", "d", "e"]
|
||||||
|
conditions = {"data": Condition(input=input_, target=output_)}
|
||||||
|
|
||||||
|
|
||||||
model = FeedForward(2, 1)
|
model = FeedForward(2, 1)
|
||||||
|
|
||||||
|
|
||||||
|
class Model(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.lift = torch.nn.Linear(5, 10)
|
||||||
|
self.activation = torch.nn.Tanh()
|
||||||
|
self.output = torch.nn.Linear(10, 1)
|
||||||
|
|
||||||
|
self.conv = GCNConv(10, 10)
|
||||||
|
|
||||||
|
def forward(self, batch):
|
||||||
|
|
||||||
|
x = batch.x
|
||||||
|
edge_index = batch.edge_index
|
||||||
|
for _ in range(1):
|
||||||
|
y = self.lift(x)
|
||||||
|
y = self.activation(y)
|
||||||
|
y = self.conv(y, edge_index)
|
||||||
|
y = self.activation(y)
|
||||||
|
y = self.output(y)
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
graph_model = Model()
|
||||||
|
|
||||||
|
|
||||||
def test_constructor():
|
def test_constructor():
|
||||||
SupervisedSolver(problem=TensorProblem(), model=model)
|
SupervisedSolver(problem=TensorProblem(), model=model)
|
||||||
SupervisedSolver(problem=LabelTensorProblem(), model=model)
|
SupervisedSolver(problem=LabelTensorProblem(), model=model)
|
||||||
@@ -59,6 +116,24 @@ def test_solver_train(use_lt, batch_size, compile):
|
|||||||
assert isinstance(solver.model, OptimizedModule)
|
assert isinstance(solver.model, OptimizedModule)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
|
||||||
|
@pytest.mark.parametrize("use_lt", [True, False])
|
||||||
|
def test_solver_train_graph(batch_size, use_lt):
|
||||||
|
problem = GraphProblemLT() if use_lt else GraphProblem()
|
||||||
|
solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt)
|
||||||
|
trainer = Trainer(
|
||||||
|
solver=solver,
|
||||||
|
max_epochs=2,
|
||||||
|
accelerator="cpu",
|
||||||
|
batch_size=batch_size,
|
||||||
|
train_size=1.0,
|
||||||
|
test_size=0.0,
|
||||||
|
val_size=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("use_lt", [True, False])
|
@pytest.mark.parametrize("use_lt", [True, False])
|
||||||
@pytest.mark.parametrize("compile", [True, False])
|
@pytest.mark.parametrize("compile", [True, False])
|
||||||
def test_solver_validation(use_lt, compile):
|
def test_solver_validation(use_lt, compile):
|
||||||
@@ -79,6 +154,24 @@ def test_solver_validation(use_lt, compile):
|
|||||||
assert isinstance(solver.model, OptimizedModule)
|
assert isinstance(solver.model, OptimizedModule)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
|
||||||
|
@pytest.mark.parametrize("use_lt", [True, False])
|
||||||
|
def test_solver_validation_graph(batch_size, use_lt):
|
||||||
|
problem = GraphProblemLT() if use_lt else GraphProblem()
|
||||||
|
solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt)
|
||||||
|
trainer = Trainer(
|
||||||
|
solver=solver,
|
||||||
|
max_epochs=2,
|
||||||
|
accelerator="cpu",
|
||||||
|
batch_size=batch_size,
|
||||||
|
train_size=0.9,
|
||||||
|
val_size=0.1,
|
||||||
|
test_size=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("use_lt", [True, False])
|
@pytest.mark.parametrize("use_lt", [True, False])
|
||||||
@pytest.mark.parametrize("compile", [True, False])
|
@pytest.mark.parametrize("compile", [True, False])
|
||||||
def test_solver_test(use_lt, compile):
|
def test_solver_test(use_lt, compile):
|
||||||
@@ -99,6 +192,24 @@ def test_solver_test(use_lt, compile):
|
|||||||
assert isinstance(solver.model, OptimizedModule)
|
assert isinstance(solver.model, OptimizedModule)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
|
||||||
|
@pytest.mark.parametrize("use_lt", [True, False])
|
||||||
|
def test_solver_test_graph(batch_size, use_lt):
|
||||||
|
problem = GraphProblemLT() if use_lt else GraphProblem()
|
||||||
|
solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt)
|
||||||
|
trainer = Trainer(
|
||||||
|
solver=solver,
|
||||||
|
max_epochs=2,
|
||||||
|
accelerator="cpu",
|
||||||
|
batch_size=batch_size,
|
||||||
|
train_size=0.8,
|
||||||
|
val_size=0.1,
|
||||||
|
test_size=0.1,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer.test()
|
||||||
|
|
||||||
|
|
||||||
def test_train_load_restore():
|
def test_train_load_restore():
|
||||||
dir = "tests/test_solver/tmp/"
|
dir = "tests/test_solver/tmp/"
|
||||||
problem = LabelTensorProblem()
|
problem = LabelTensorProblem()
|
||||||
|
|||||||
Reference in New Issue
Block a user