Add tests for supervised solver with graph-based models (#480)

This commit is contained in:
Filippo Olivo
2025-03-10 18:14:22 +01:00
committed by Nicola Demo
parent 4177bfbb50
commit 2ae4a94e49

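For orientation, a minimal sketch (not part of the commit) of the pattern these tests exercise: per-sample tensors are wrapped into KNNGraph objects, collected into a problem through Condition, and trained with SupervisedSolver via a Trainer. All calls mirror the diff below; the class names GraphModel and GraphProblem and the tensor shapes are illustrative only.

import torch
from torch_geometric.nn import GCNConv
from pina import Condition
from pina.graph import KNNGraph
from pina.problem import AbstractProblem
from pina.solver import SupervisedSolver
from pina.trainer import Trainer

# Wrap per-sample node features and positions into k-nearest-neighbour graphs.
x = torch.rand((100, 20, 5))       # 100 samples, 20 nodes, 5 features per node
pos = torch.rand((100, 20, 2))     # node coordinates used to build the kNN edges
target = torch.rand((100, 20, 1))  # per-node regression target
graphs = [
    KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True)
    for x_, pos_ in zip(x, pos)
]


class GraphProblem(AbstractProblem):
    output_variables = None
    conditions = {"data": Condition(input=graphs, target=target)}


class GraphModel(torch.nn.Module):
    # Compact version of the Model class in the diff: lift, one GCN layer, output head.
    def __init__(self):
        super().__init__()
        self.lift = torch.nn.Linear(5, 10)
        self.conv = GCNConv(10, 10)
        self.head = torch.nn.Linear(10, 1)

    def forward(self, batch):
        y = torch.tanh(self.lift(batch.x))
        y = torch.tanh(self.conv(y, batch.edge_index))
        return self.head(y)


solver = SupervisedSolver(problem=GraphProblem(), model=GraphModel(), use_lt=False)
trainer = Trainer(
    solver=solver,
    max_epochs=2,
    accelerator="cpu",
    batch_size=5,
    train_size=1.0,
    val_size=0.0,
    test_size=0.0,
)
trainer.train()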

@@ -1,12 +1,14 @@
import torch
import pytest
from torch._dynamo.eval_frame import OptimizedModule
from torch_geometric.nn import GCNConv
from pina import Condition, LabelTensor
from pina.condition import InputTargetCondition
from pina.problem import AbstractProblem
from pina.solver import SupervisedSolver
from pina.model import FeedForward
from pina.trainer import Trainer
from pina.graph import KNNGraph


class LabelTensorProblem(AbstractProblem):
@@ -28,9 +30,64 @@ class TensorProblem(AbstractProblem):
    }


x = torch.rand((100, 20, 5))
pos = torch.rand((100, 20, 2))
output_ = torch.rand((100, 20, 1))
input_ = [
    KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True)
    for x_, pos_ in zip(x, pos)
]


class GraphProblem(AbstractProblem):
    output_variables = None
    conditions = {"data": Condition(input=input_, target=output_)}


x = LabelTensor(torch.rand((100, 20, 5)), ["a", "b", "c", "d", "e"])
pos = LabelTensor(torch.rand((100, 20, 2)), ["x", "y"])
output_ = LabelTensor(torch.rand((100, 20, 1)), ["u"])
input_ = [
    KNNGraph(x=x[i], pos=pos[i], neighbours=3, edge_attr=True)
    for i in range(len(x))
]


class GraphProblemLT(AbstractProblem):
    output_variables = ["u"]
    input_variables = ["a", "b", "c", "d", "e"]
    conditions = {"data": Condition(input=input_, target=output_)}
model = FeedForward(2, 1)


# Simple graph network: linear lift, one GCN convolution, linear output head.
class Model(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.lift = torch.nn.Linear(5, 10)
        self.activation = torch.nn.Tanh()
        self.output = torch.nn.Linear(10, 1)
        self.conv = GCNConv(10, 10)

    def forward(self, batch):
        # Operate on the batched graph's node features and connectivity.
        x = batch.x
        edge_index = batch.edge_index
        for _ in range(1):
            y = self.lift(x)
            y = self.activation(y)
            y = self.conv(y, edge_index)
            y = self.activation(y)
            y = self.output(y)
        return y


graph_model = Model()
def test_constructor():
    SupervisedSolver(problem=TensorProblem(), model=model)
    SupervisedSolver(problem=LabelTensorProblem(), model=model)
@@ -59,6 +116,24 @@ def test_solver_train(use_lt, batch_size, compile):
        assert isinstance(solver.model, OptimizedModule)


@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("use_lt", [True, False])
def test_solver_train_graph(batch_size, use_lt):
    problem = GraphProblemLT() if use_lt else GraphProblem()
    solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt)
    trainer = Trainer(
        solver=solver,
        max_epochs=2,
        accelerator="cpu",
        batch_size=batch_size,
        train_size=1.0,
        test_size=0.0,
        val_size=0.0,
    )
    trainer.train()
@pytest.mark.parametrize("use_lt", [True, False]) @pytest.mark.parametrize("use_lt", [True, False])
@pytest.mark.parametrize("compile", [True, False]) @pytest.mark.parametrize("compile", [True, False])
def test_solver_validation(use_lt, compile): def test_solver_validation(use_lt, compile):
@@ -79,6 +154,24 @@ def test_solver_validation(use_lt, compile):
        assert isinstance(solver.model, OptimizedModule)


@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("use_lt", [True, False])
def test_solver_validation_graph(batch_size, use_lt):
    problem = GraphProblemLT() if use_lt else GraphProblem()
    solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt)
    trainer = Trainer(
        solver=solver,
        max_epochs=2,
        accelerator="cpu",
        batch_size=batch_size,
        train_size=0.9,
        val_size=0.1,
        test_size=0.0,
    )
    trainer.train()
@pytest.mark.parametrize("use_lt", [True, False]) @pytest.mark.parametrize("use_lt", [True, False])
@pytest.mark.parametrize("compile", [True, False]) @pytest.mark.parametrize("compile", [True, False])
def test_solver_test(use_lt, compile): def test_solver_test(use_lt, compile):
@@ -99,6 +192,24 @@ def test_solver_test(use_lt, compile):
        assert isinstance(solver.model, OptimizedModule)


@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("use_lt", [True, False])
def test_solver_test_graph(batch_size, use_lt):
    problem = GraphProblemLT() if use_lt else GraphProblem()
    solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt)
    trainer = Trainer(
        solver=solver,
        max_epochs=2,
        accelerator="cpu",
        batch_size=batch_size,
        train_size=0.8,
        val_size=0.1,
        test_size=0.1,
    )
    trainer.test()
def test_train_load_restore():
    dir = "tests/test_solver/tmp/"
    problem = LabelTensorProblem()