Add test supervised solver for graph based models (#480)

This commit is contained in:
Filippo Olivo
2025-03-10 18:14:22 +01:00
committed by Nicola Demo
parent 4177bfbb50
commit 2ae4a94e49

View File

@@ -1,12 +1,14 @@
import torch
import pytest
from torch._dynamo.eval_frame import OptimizedModule
from torch_geometric.nn import GCNConv
from pina import Condition, LabelTensor
from pina.condition import InputTargetCondition
from pina.problem import AbstractProblem
from pina.solver import SupervisedSolver
from pina.model import FeedForward
from pina.trainer import Trainer
from torch._dynamo.eval_frame import OptimizedModule
from pina.graph import KNNGraph
class LabelTensorProblem(AbstractProblem):
@@ -28,9 +30,64 @@ class TensorProblem(AbstractProblem):
}
x = torch.rand((100, 20, 5))
pos = torch.rand((100, 20, 2))
output_ = torch.rand((100, 20, 1))
input_ = [
KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True)
for x_, pos_ in zip(x, pos)
]
class GraphProblem(AbstractProblem):
output_variables = None
conditions = {"data": Condition(input=input_, target=output_)}
x = LabelTensor(torch.rand((100, 20, 5)), ["a", "b", "c", "d", "e"])
pos = LabelTensor(torch.rand((100, 20, 2)), ["x", "y"])
output_ = LabelTensor(torch.rand((100, 20, 1)), ["u"])
input_ = [
KNNGraph(x=x[i], pos=pos[i], neighbours=3, edge_attr=True)
for i in range(len(x))
]
class GraphProblemLT(AbstractProblem):
output_variables = ["u"]
input_variables = ["a", "b", "c", "d", "e"]
conditions = {"data": Condition(input=input_, target=output_)}
model = FeedForward(2, 1)
class Model(torch.nn.Module):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lift = torch.nn.Linear(5, 10)
self.activation = torch.nn.Tanh()
self.output = torch.nn.Linear(10, 1)
self.conv = GCNConv(10, 10)
def forward(self, batch):
x = batch.x
edge_index = batch.edge_index
for _ in range(1):
y = self.lift(x)
y = self.activation(y)
y = self.conv(y, edge_index)
y = self.activation(y)
y = self.output(y)
return y
graph_model = Model()
def test_constructor():
SupervisedSolver(problem=TensorProblem(), model=model)
SupervisedSolver(problem=LabelTensorProblem(), model=model)
@@ -59,6 +116,24 @@ def test_solver_train(use_lt, batch_size, compile):
assert isinstance(solver.model, OptimizedModule)
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("use_lt", [True, False])
def test_solver_train_graph(batch_size, use_lt):
problem = GraphProblemLT() if use_lt else GraphProblem()
solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt)
trainer = Trainer(
solver=solver,
max_epochs=2,
accelerator="cpu",
batch_size=batch_size,
train_size=1.0,
test_size=0.0,
val_size=0.0,
)
trainer.train()
@pytest.mark.parametrize("use_lt", [True, False])
@pytest.mark.parametrize("compile", [True, False])
def test_solver_validation(use_lt, compile):
@@ -79,6 +154,24 @@ def test_solver_validation(use_lt, compile):
assert isinstance(solver.model, OptimizedModule)
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("use_lt", [True, False])
def test_solver_validation_graph(batch_size, use_lt):
problem = GraphProblemLT() if use_lt else GraphProblem()
solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt)
trainer = Trainer(
solver=solver,
max_epochs=2,
accelerator="cpu",
batch_size=batch_size,
train_size=0.9,
val_size=0.1,
test_size=0.0,
)
trainer.train()
@pytest.mark.parametrize("use_lt", [True, False])
@pytest.mark.parametrize("compile", [True, False])
def test_solver_test(use_lt, compile):
@@ -99,6 +192,24 @@ def test_solver_test(use_lt, compile):
assert isinstance(solver.model, OptimizedModule)
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("use_lt", [True, False])
def test_solver_test_graph(batch_size, use_lt):
problem = GraphProblemLT() if use_lt else GraphProblem()
solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt)
trainer = Trainer(
solver=solver,
max_epochs=2,
accelerator="cpu",
batch_size=batch_size,
train_size=0.8,
val_size=0.1,
test_size=0.1,
)
trainer.test()
def test_train_load_restore():
dir = "tests/test_solver/tmp/"
problem = LabelTensorProblem()