Add test supervised solver for graph based models (#480)
This commit is contained in:
committed by
Nicola Demo
parent
4177bfbb50
commit
2ae4a94e49
@@ -1,12 +1,14 @@
|
||||
import torch
|
||||
import pytest
|
||||
from torch._dynamo.eval_frame import OptimizedModule
|
||||
from torch_geometric.nn import GCNConv
|
||||
from pina import Condition, LabelTensor
|
||||
from pina.condition import InputTargetCondition
|
||||
from pina.problem import AbstractProblem
|
||||
from pina.solver import SupervisedSolver
|
||||
from pina.model import FeedForward
|
||||
from pina.trainer import Trainer
|
||||
from torch._dynamo.eval_frame import OptimizedModule
|
||||
from pina.graph import KNNGraph
|
||||
|
||||
|
||||
class LabelTensorProblem(AbstractProblem):
|
||||
@@ -28,9 +30,64 @@ class TensorProblem(AbstractProblem):
|
||||
}
|
||||
|
||||
|
||||
x = torch.rand((100, 20, 5))
|
||||
pos = torch.rand((100, 20, 2))
|
||||
output_ = torch.rand((100, 20, 1))
|
||||
input_ = [
|
||||
KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True)
|
||||
for x_, pos_ in zip(x, pos)
|
||||
]
|
||||
|
||||
|
||||
class GraphProblem(AbstractProblem):
|
||||
output_variables = None
|
||||
conditions = {"data": Condition(input=input_, target=output_)}
|
||||
|
||||
|
||||
x = LabelTensor(torch.rand((100, 20, 5)), ["a", "b", "c", "d", "e"])
|
||||
pos = LabelTensor(torch.rand((100, 20, 2)), ["x", "y"])
|
||||
output_ = LabelTensor(torch.rand((100, 20, 1)), ["u"])
|
||||
input_ = [
|
||||
KNNGraph(x=x[i], pos=pos[i], neighbours=3, edge_attr=True)
|
||||
for i in range(len(x))
|
||||
]
|
||||
|
||||
|
||||
class GraphProblemLT(AbstractProblem):
|
||||
output_variables = ["u"]
|
||||
input_variables = ["a", "b", "c", "d", "e"]
|
||||
conditions = {"data": Condition(input=input_, target=output_)}
|
||||
|
||||
|
||||
model = FeedForward(2, 1)
|
||||
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lift = torch.nn.Linear(5, 10)
|
||||
self.activation = torch.nn.Tanh()
|
||||
self.output = torch.nn.Linear(10, 1)
|
||||
|
||||
self.conv = GCNConv(10, 10)
|
||||
|
||||
def forward(self, batch):
|
||||
|
||||
x = batch.x
|
||||
edge_index = batch.edge_index
|
||||
for _ in range(1):
|
||||
y = self.lift(x)
|
||||
y = self.activation(y)
|
||||
y = self.conv(y, edge_index)
|
||||
y = self.activation(y)
|
||||
y = self.output(y)
|
||||
return y
|
||||
|
||||
|
||||
graph_model = Model()
|
||||
|
||||
|
||||
def test_constructor():
|
||||
SupervisedSolver(problem=TensorProblem(), model=model)
|
||||
SupervisedSolver(problem=LabelTensorProblem(), model=model)
|
||||
@@ -59,6 +116,24 @@ def test_solver_train(use_lt, batch_size, compile):
|
||||
assert isinstance(solver.model, OptimizedModule)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
|
||||
@pytest.mark.parametrize("use_lt", [True, False])
|
||||
def test_solver_train_graph(batch_size, use_lt):
|
||||
problem = GraphProblemLT() if use_lt else GraphProblem()
|
||||
solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt)
|
||||
trainer = Trainer(
|
||||
solver=solver,
|
||||
max_epochs=2,
|
||||
accelerator="cpu",
|
||||
batch_size=batch_size,
|
||||
train_size=1.0,
|
||||
test_size=0.0,
|
||||
val_size=0.0,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_lt", [True, False])
|
||||
@pytest.mark.parametrize("compile", [True, False])
|
||||
def test_solver_validation(use_lt, compile):
|
||||
@@ -79,6 +154,24 @@ def test_solver_validation(use_lt, compile):
|
||||
assert isinstance(solver.model, OptimizedModule)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
|
||||
@pytest.mark.parametrize("use_lt", [True, False])
|
||||
def test_solver_validation_graph(batch_size, use_lt):
|
||||
problem = GraphProblemLT() if use_lt else GraphProblem()
|
||||
solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt)
|
||||
trainer = Trainer(
|
||||
solver=solver,
|
||||
max_epochs=2,
|
||||
accelerator="cpu",
|
||||
batch_size=batch_size,
|
||||
train_size=0.9,
|
||||
val_size=0.1,
|
||||
test_size=0.0,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_lt", [True, False])
|
||||
@pytest.mark.parametrize("compile", [True, False])
|
||||
def test_solver_test(use_lt, compile):
|
||||
@@ -99,6 +192,24 @@ def test_solver_test(use_lt, compile):
|
||||
assert isinstance(solver.model, OptimizedModule)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
|
||||
@pytest.mark.parametrize("use_lt", [True, False])
|
||||
def test_solver_test_graph(batch_size, use_lt):
|
||||
problem = GraphProblemLT() if use_lt else GraphProblem()
|
||||
solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt)
|
||||
trainer = Trainer(
|
||||
solver=solver,
|
||||
max_epochs=2,
|
||||
accelerator="cpu",
|
||||
batch_size=batch_size,
|
||||
train_size=0.8,
|
||||
val_size=0.1,
|
||||
test_size=0.1,
|
||||
)
|
||||
|
||||
trainer.test()
|
||||
|
||||
|
||||
def test_train_load_restore():
|
||||
dir = "tests/test_solver/tmp/"
|
||||
problem = LabelTensorProblem()
|
||||
|
||||
Reference in New Issue
Block a user