add mutual solver-weighting link
This commit is contained in:
committed by
Giovanni Canali
parent
973d0c05ab
commit
bacd7e202a
@@ -2,64 +2,32 @@ import pytest
|
||||
from pina import Trainer
|
||||
from pina.solver import PINN
|
||||
from pina.model import FeedForward
|
||||
from pina.problem.zoo import Poisson2DSquareProblem
|
||||
from pina.loss import NeuralTangentKernelWeighting
|
||||
from pina.problem.zoo import Poisson2DSquareProblem
|
||||
|
||||
|
||||
# Initialize problem and model
|
||||
problem = Poisson2DSquareProblem()
|
||||
condition_names = problem.conditions.keys()
|
||||
problem.discretise_domain(10)
|
||||
model = FeedForward(len(problem.input_variables), len(problem.output_variables))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model,alpha",
|
||||
[
|
||||
(
|
||||
FeedForward(
|
||||
len(problem.input_variables), len(problem.output_variables)
|
||||
),
|
||||
0.5,
|
||||
)
|
||||
],
|
||||
)
|
||||
def test_constructor(model, alpha):
|
||||
NeuralTangentKernelWeighting(model=model, alpha=alpha)
|
||||
@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
|
||||
def test_constructor(alpha):
|
||||
NeuralTangentKernelWeighting(alpha=alpha)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", [0.5])
|
||||
def test_wrong_constructor1(model):
|
||||
# Should fail if alpha is not >= 0
|
||||
with pytest.raises(ValueError):
|
||||
NeuralTangentKernelWeighting(model)
|
||||
NeuralTangentKernelWeighting(alpha=-0.1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model,alpha",
|
||||
[
|
||||
(
|
||||
FeedForward(
|
||||
len(problem.input_variables), len(problem.output_variables)
|
||||
),
|
||||
1.2,
|
||||
)
|
||||
],
|
||||
)
|
||||
def test_wrong_constructor2(model, alpha):
|
||||
# Should fail if alpha is not <= 1
|
||||
with pytest.raises(ValueError):
|
||||
NeuralTangentKernelWeighting(model, alpha)
|
||||
NeuralTangentKernelWeighting(alpha=1.1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model,alpha",
|
||||
[
|
||||
(
|
||||
FeedForward(
|
||||
len(problem.input_variables), len(problem.output_variables)
|
||||
),
|
||||
0.5,
|
||||
)
|
||||
],
|
||||
)
|
||||
def test_train_aggregation(model, alpha):
|
||||
weighting = NeuralTangentKernelWeighting(model=model, alpha=alpha)
|
||||
problem.discretise_domain(50)
|
||||
@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
|
||||
def test_train_aggregation(alpha):
|
||||
weighting = NeuralTangentKernelWeighting(alpha=alpha)
|
||||
solver = PINN(problem=problem, model=model, weighting=weighting)
|
||||
trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
|
||||
trainer.train()
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from pina import Trainer
|
||||
from pina.solver import PINN
|
||||
from pina.model import FeedForward
|
||||
from pina.problem.zoo import Poisson2DSquareProblem
|
||||
from pina.loss import ScalarWeighting
|
||||
from pina.problem.zoo import Poisson2DSquareProblem
|
||||
|
||||
|
||||
# Initialize problem and model
|
||||
problem = Poisson2DSquareProblem()
|
||||
problem.discretise_domain(50)
|
||||
model = FeedForward(len(problem.input_variables), len(problem.output_variables))
|
||||
condition_names = problem.conditions.keys()
|
||||
print(problem.conditions.keys())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -19,11 +20,13 @@ print(problem.conditions.keys())
|
||||
def test_constructor(weights):
|
||||
ScalarWeighting(weights=weights)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("weights", ["a", [1, 2, 3]])
|
||||
def test_wrong_constructor(weights):
|
||||
# Should fail if weights are not a scalar
|
||||
with pytest.raises(ValueError):
|
||||
ScalarWeighting(weights=weights)
|
||||
ScalarWeighting(weights="invalid")
|
||||
|
||||
# Should fail if weights are not a dictionary
|
||||
with pytest.raises(ValueError):
|
||||
ScalarWeighting(weights=[1, 2, 3])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -45,7 +48,6 @@ def test_aggregate(weights):
|
||||
)
|
||||
def test_train_aggregation(weights):
|
||||
weighting = ScalarWeighting(weights=weights)
|
||||
problem.discretise_domain(50)
|
||||
solver = PINN(problem=problem, model=model, weighting=weighting)
|
||||
trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
|
||||
trainer.train()
|
||||
Reference in New Issue
Block a user