add linear weighting
This commit is contained in:
committed by
Giovanni Canali
parent
96402baf20
commit
ef3542486c
95
tests/test_weighting/test_linear_weighting.py
Normal file
95
tests/test_weighting/test_linear_weighting.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import math
|
||||
import pytest
|
||||
from pina import Trainer
|
||||
from pina.solver import PINN
|
||||
from pina.model import FeedForward
|
||||
from pina.loss import LinearWeighting
|
||||
from pina.problem.zoo import Poisson2DSquareProblem
|
||||
|
||||
|
||||
# Initialize problem and model
|
||||
problem = Poisson2DSquareProblem()
|
||||
problem.discretise_domain(10)
|
||||
model = FeedForward(len(problem.input_variables), len(problem.output_variables))
|
||||
|
||||
# Weights for testing
|
||||
init_weight_1 = {cond: 3 for cond in problem.conditions.keys()}
|
||||
init_weight_2 = {cond: 4 for cond in problem.conditions.keys()}
|
||||
final_weight_1 = {cond: 1 for cond in problem.conditions.keys()}
|
||||
final_weight_2 = {cond: 5 for cond in problem.conditions.keys()}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("initial_weights", [init_weight_1, init_weight_2])
|
||||
@pytest.mark.parametrize("final_weights", [final_weight_1, final_weight_2])
|
||||
@pytest.mark.parametrize("target_epoch", [5, 10])
|
||||
def test_constructor(initial_weights, final_weights, target_epoch):
|
||||
LinearWeighting(
|
||||
initial_weights=initial_weights,
|
||||
final_weights=final_weights,
|
||||
target_epoch=target_epoch,
|
||||
)
|
||||
|
||||
# Should fail if initial_weights is not a dictionary
|
||||
with pytest.raises(ValueError):
|
||||
LinearWeighting(
|
||||
initial_weights=[1, 1, 1],
|
||||
final_weights=final_weights,
|
||||
target_epoch=target_epoch,
|
||||
)
|
||||
|
||||
# Should fail if final_weights is not a dictionary
|
||||
with pytest.raises(ValueError):
|
||||
LinearWeighting(
|
||||
initial_weights=initial_weights,
|
||||
final_weights=[1, 1, 1],
|
||||
target_epoch=target_epoch,
|
||||
)
|
||||
|
||||
# Should fail if target_epoch is not an integer
|
||||
with pytest.raises(AssertionError):
|
||||
LinearWeighting(
|
||||
initial_weights=initial_weights,
|
||||
final_weights=final_weights,
|
||||
target_epoch=1.5,
|
||||
)
|
||||
|
||||
# Should fail if target_epoch is not positive
|
||||
with pytest.raises(AssertionError):
|
||||
LinearWeighting(
|
||||
initial_weights=initial_weights,
|
||||
final_weights=final_weights,
|
||||
target_epoch=0,
|
||||
)
|
||||
|
||||
# Should fail if dictionary keys do not match
|
||||
with pytest.raises(ValueError):
|
||||
LinearWeighting(
|
||||
initial_weights={list(initial_weights.keys())[0]: 1},
|
||||
final_weights=final_weights,
|
||||
target_epoch=target_epoch,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("initial_weights", [init_weight_1, init_weight_2])
|
||||
@pytest.mark.parametrize("final_weights", [final_weight_1, final_weight_2])
|
||||
@pytest.mark.parametrize("target_epoch", [5, 10])
|
||||
def test_train_aggregation(initial_weights, final_weights, target_epoch):
|
||||
weighting = LinearWeighting(
|
||||
initial_weights=initial_weights,
|
||||
final_weights=final_weights,
|
||||
target_epoch=target_epoch,
|
||||
)
|
||||
solver = PINN(problem=problem, model=model, weighting=weighting)
|
||||
trainer = Trainer(solver=solver, max_epochs=target_epoch, accelerator="cpu")
|
||||
trainer.train()
|
||||
|
||||
# Check that weights are updated correctly
|
||||
assert all(
|
||||
math.isclose(
|
||||
weighting.last_saved_weights()[cond],
|
||||
final_weights[cond],
|
||||
rel_tol=1e-5,
|
||||
abs_tol=1e-8,
|
||||
)
|
||||
for cond in final_weights.keys()
|
||||
)
|
||||
Reference in New Issue
Block a user