Files
PINA/tests/test_solver/test_causal_pinn.py
Dario Coscia 6dd7bd2825 Refactoring solvers (#541)
* Refactoring solvers

* Simplify logic compile
* Improve and update doc
* Create SupervisedSolverInterface
* Specialize SupervisedSolver and ReducedOrderModelSolver
* Create EnsembleSolverInterface + EnsembleSupervisedSolver
* Create tests ensemble solvers

* formatter

* codacy

* fix issues + speedup test
2025-04-23 18:53:30 +02:00

161 lines
4.5 KiB
Python

import torch
import pytest
from pina import LabelTensor, Condition
from pina.problem import SpatialProblem
from pina.solver import CausalPINN
from pina.trainer import Trainer
from pina.model import FeedForward
from pina.problem.zoo import DiffusionReactionProblem
from pina.condition import (
InputTargetCondition,
InputEquationCondition,
DomainEquationCondition,
)
from torch._dynamo.eval_frame import OptimizedModule
class DummySpatialProblem(SpatialProblem):
    """
    Minimal :class:`SpatialProblem` subclass used as a negative test case.

    It declares a single output variable ``"u"``, no conditions and no
    spatial domain. ``test_constructor`` passes an instance of it to
    :class:`CausalPINN` and expects a ``ValueError``.
    """

    spatial_domain = None
    conditions = {}
    output_variables = ["u"]
# Problem shared by every test in this module, discretised with 10 points.
problem = DiffusionReactionProblem()
problem.discretise_domain(10)

# Random input/target pair attached as an extra "data" condition so that the
# supervised (input-target) path is exercised alongside the physics ones.
input_pts = LabelTensor(
    torch.rand(10, len(problem.input_variables)),
    problem.input_variables,
)
output_pts = LabelTensor(
    torch.rand(10, len(problem.output_variables)),
    problem.output_variables,
)
problem.conditions["data"] = Condition(input=input_pts, target=output_pts)

# Network mapping the problem inputs to the problem outputs.
model = FeedForward(
    len(problem.input_variables),
    len(problem.output_variables),
)
@pytest.mark.parametrize("problem", [problem])
@pytest.mark.parametrize("eps", [100, 100.1])
def test_constructor(problem, eps):
    """
    Check the CausalPINN constructor.

    Building the solver on a ``DummySpatialProblem`` must raise
    ``ValueError``; building it on the parametrized problem with the given
    ``eps`` must succeed and expose the expected accepted condition types.
    """
    expected_condition_types = (
        InputTargetCondition,
        InputEquationCondition,
        DomainEquationCondition,
    )

    # Rejected problem type.
    with pytest.raises(ValueError):
        CausalPINN(model=model, problem=DummySpatialProblem())

    # Accepted problem type, with both an int and a float eps.
    causal_pinn = CausalPINN(model=model, problem=problem, eps=eps)
    assert causal_pinn.accepted_conditions_types == expected_condition_types
@pytest.mark.parametrize("problem", [problem])
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("compile", [True, False])
def test_solver_train(problem, batch_size, compile):
    """
    Train a CausalPINN for two epochs on CPU using only a training split.

    If the trainer reports ``compile`` as truthy after training, the solver's
    model must be a ``torch`` ``OptimizedModule``.
    """
    causal_pinn = CausalPINN(model=model, problem=problem)

    # All data goes to training; no validation or test split.
    split_sizes = {"train_size": 1.0, "val_size": 0.0, "test_size": 0.0}
    runner = Trainer(
        solver=causal_pinn,
        max_epochs=2,
        accelerator="cpu",
        batch_size=batch_size,
        compile=compile,
        **split_sizes,
    )
    runner.train()

    if not runner.compile:
        return
    assert isinstance(causal_pinn.model, OptimizedModule)
@pytest.mark.parametrize("problem", [problem])
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("compile", [True, False])
def test_solver_validation(problem, batch_size, compile):
    """
    Train a CausalPINN for two epochs on CPU with a 90/10 train/val split.

    If the trainer reports ``compile`` as truthy after training, the solver's
    model must be a ``torch`` ``OptimizedModule``.
    """
    causal_pinn = CausalPINN(model=model, problem=problem)

    # Training plus a small validation split; no test split.
    split_sizes = {"train_size": 0.9, "val_size": 0.1, "test_size": 0.0}
    runner = Trainer(
        solver=causal_pinn,
        max_epochs=2,
        accelerator="cpu",
        batch_size=batch_size,
        compile=compile,
        **split_sizes,
    )
    runner.train()

    if not runner.compile:
        return
    assert isinstance(causal_pinn.model, OptimizedModule)
@pytest.mark.parametrize("problem", [problem])
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("compile", [True, False])
def test_solver_test(problem, batch_size, compile):
    """
    Run the trainer's test stage for a CausalPINN on CPU.

    The data is split 70/20/10 into train/val/test. If the trainer reports
    ``compile`` as truthy afterwards, the solver's model must be a ``torch``
    ``OptimizedModule``.
    """
    causal_pinn = CausalPINN(model=model, problem=problem)

    split_sizes = {"train_size": 0.7, "val_size": 0.2, "test_size": 0.1}
    runner = Trainer(
        solver=causal_pinn,
        max_epochs=2,
        accelerator="cpu",
        batch_size=batch_size,
        compile=compile,
        **split_sizes,
    )
    # Only the test stage is executed here, not training.
    runner.test()

    if not runner.compile:
        return
    assert isinstance(causal_pinn.model, OptimizedModule)
@pytest.mark.parametrize("problem", [problem])
def test_train_load_restore(problem):
    """
    Train a CausalPINN, resume from its checkpoint, and reload the solver.

    The test trains for five epochs with ``default_root_dir`` pointing at a
    temporary directory, resumes a second trainer from the written
    checkpoint, rebuilds a solver with ``load_from_checkpoint`` and checks
    that the reloaded solver produces the same output as the trained one.

    The temporary directory is removed in a ``finally`` clause so that a
    failing assertion does not leave a stale ``version_0`` checkpoint behind
    for the next run to pick up.
    """
    import shutil

    tmp_dir = "tests/test_solver/tmp"
    # Checkpoint file the test expects after the five-epoch, full-batch run.
    ckpt_path = (
        f"{tmp_dir}/lightning_logs/version_0/checkpoints/"
        + "epoch=4-step=5.ckpt"
    )

    try:
        solver = CausalPINN(model=model, problem=problem)
        trainer = Trainer(
            solver=solver,
            max_epochs=5,
            accelerator="cpu",
            batch_size=None,
            train_size=0.7,
            val_size=0.2,
            test_size=0.1,
            default_root_dir=tmp_dir,
        )
        trainer.train()

        # restore
        # NOTE(review): this trainer has no default_root_dir, so any logs it
        # writes would land outside tmp_dir - confirm whether that is wanted.
        new_trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
        new_trainer.train(ckpt_path=ckpt_path)

        # loading
        new_solver = CausalPINN.load_from_checkpoint(
            ckpt_path,
            problem=problem,
            model=model,
        )

        test_pts = LabelTensor(torch.rand(20, 2), problem.input_variables)
        restored_out = new_solver.forward(test_pts)
        trained_out = solver.forward(test_pts)

        assert restored_out.shape == (20, 1)
        assert restored_out.shape == trained_out.shape
        torch.testing.assert_close(restored_out, trained_out)
    finally:
        # rm directories, even when training or an assertion failed;
        # ignore_errors avoids masking the real failure if nothing was written.
        shutil.rmtree(tmp_dir, ignore_errors=True)