add mutual solver-weighting link

This commit is contained in:
giovanni
2025-08-29 19:11:08 +02:00
committed by Giovanni Canali
parent 973d0c05ab
commit bacd7e202a
6 changed files with 62 additions and 76 deletions

View File

@@ -2,64 +2,32 @@ import pytest
from pina import Trainer
from pina.solver import PINN
from pina.model import FeedForward
from pina.problem.zoo import Poisson2DSquareProblem
from pina.loss import NeuralTangentKernelWeighting
from pina.problem.zoo import Poisson2DSquareProblem
# Initialize problem and model
problem = Poisson2DSquareProblem()
condition_names = problem.conditions.keys()
problem.discretise_domain(10)
model = FeedForward(len(problem.input_variables), len(problem.output_variables))
@pytest.mark.parametrize(
"model,alpha",
[
(
FeedForward(
len(problem.input_variables), len(problem.output_variables)
),
0.5,
)
],
)
def test_constructor(model, alpha):
NeuralTangentKernelWeighting(model=model, alpha=alpha)
@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
def test_constructor(alpha):
NeuralTangentKernelWeighting(alpha=alpha)
@pytest.mark.parametrize("model", [0.5])
def test_wrong_constructor1(model):
# Should fail if alpha is not >= 0
with pytest.raises(ValueError):
NeuralTangentKernelWeighting(model)
NeuralTangentKernelWeighting(alpha=-0.1)
@pytest.mark.parametrize(
"model,alpha",
[
(
FeedForward(
len(problem.input_variables), len(problem.output_variables)
),
1.2,
)
],
)
def test_wrong_constructor2(model, alpha):
# Should fail if alpha is not <= 1
with pytest.raises(ValueError):
NeuralTangentKernelWeighting(model, alpha)
NeuralTangentKernelWeighting(alpha=1.1)
@pytest.mark.parametrize(
"model,alpha",
[
(
FeedForward(
len(problem.input_variables), len(problem.output_variables)
),
0.5,
)
],
)
def test_train_aggregation(model, alpha):
weighting = NeuralTangentKernelWeighting(model=model, alpha=alpha)
problem.discretise_domain(50)
@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
def test_train_aggregation(alpha):
weighting = NeuralTangentKernelWeighting(alpha=alpha)
solver = PINN(problem=problem, model=model, weighting=weighting)
trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
trainer.train()

View File

@@ -1,16 +1,17 @@
import pytest
import torch
from pina import Trainer
from pina.solver import PINN
from pina.model import FeedForward
from pina.problem.zoo import Poisson2DSquareProblem
from pina.loss import ScalarWeighting
from pina.problem.zoo import Poisson2DSquareProblem
# Initialize problem and model
problem = Poisson2DSquareProblem()
problem.discretise_domain(50)
model = FeedForward(len(problem.input_variables), len(problem.output_variables))
condition_names = problem.conditions.keys()
print(problem.conditions.keys())
@pytest.mark.parametrize(
@@ -19,11 +20,13 @@ print(problem.conditions.keys())
def test_constructor(weights):
ScalarWeighting(weights=weights)
@pytest.mark.parametrize("weights", ["a", [1, 2, 3]])
def test_wrong_constructor(weights):
# Should fail if weights are not a scalar
with pytest.raises(ValueError):
ScalarWeighting(weights=weights)
ScalarWeighting(weights="invalid")
# Should fail if weights are not a dictionary
with pytest.raises(ValueError):
ScalarWeighting(weights=[1, 2, 3])
@pytest.mark.parametrize(
@@ -45,7 +48,6 @@ def test_aggregate(weights):
)
def test_train_aggregation(weights):
weighting = ScalarWeighting(weights=weights)
problem.discretise_domain(50)
solver = PINN(problem=problem, model=model, weighting=weighting)
trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
trainer.train()