weighting refactoring
Co-authored-by: Dario Coscia <dariocos99@gmail.com>
Committed by: Giovanni Canali
Parent: c42bdd575c
Commit: 96402baf20
@@ -1,164 +0,0 @@
import pytest
import math
from pina.solver import PINN
from pina.loss import ScalarWeighting
from pina.trainer import Trainer
from pina.model import FeedForward
from pina.problem.zoo import Poisson2DSquareProblem as Poisson
from pina.callback import LinearWeightUpdate


# Define the problem
poisson_problem = Poisson()
poisson_problem.discretise_domain(50, "grid")
cond_name = list(poisson_problem.conditions.keys())[0]

# Define the model
model = FeedForward(
    input_dimensions=len(poisson_problem.input_variables),
    output_dimensions=len(poisson_problem.output_variables),
    layers=[32, 32],
)

# Define the weighting schema
weights_dict = {key: 1 for key in poisson_problem.conditions.keys()}
weighting = ScalarWeighting(weights=weights_dict)

# Define the solver
solver = PINN(problem=poisson_problem, model=model, weighting=weighting)

# Value used for testing
epochs = 10


@pytest.mark.parametrize("initial_value", [1, 5.5])
@pytest.mark.parametrize("target_value", [10, 25.5])
def test_constructor(initial_value, target_value):
    LinearWeightUpdate(
        target_epoch=epochs,
        condition_name=cond_name,
        initial_value=initial_value,
        target_value=target_value,
    )

    # Target_epoch must be int
    with pytest.raises(ValueError):
        LinearWeightUpdate(
            target_epoch=10.0,
            condition_name=cond_name,
            initial_value=0,
            target_value=1,
        )

    # Condition_name must be str
    with pytest.raises(ValueError):
        LinearWeightUpdate(
            target_epoch=epochs,
            condition_name=100,
            initial_value=0,
            target_value=1,
        )

    # Initial_value must be float or int
    with pytest.raises(ValueError):
        LinearWeightUpdate(
            target_epoch=epochs,
            condition_name=cond_name,
            initial_value="0",
            target_value=1,
        )

    # Target_value must be float or int
    with pytest.raises(ValueError):
        LinearWeightUpdate(
            target_epoch=epochs,
            condition_name=cond_name,
            initial_value=0,
            target_value="1",
        )


@pytest.mark.parametrize("initial_value, target_value", [(1, 10), (10, 1)])
def test_training(initial_value, target_value):
    callback = LinearWeightUpdate(
        target_epoch=epochs,
        condition_name=cond_name,
        initial_value=initial_value,
        target_value=target_value,
    )
    trainer = Trainer(
        solver=solver,
        callbacks=[callback],
        accelerator="cpu",
        max_epochs=epochs,
    )
    trainer.train()

    # Check that the final weight value matches the target value
    final_value = solver.weighting.weights[cond_name]
    assert math.isclose(final_value, target_value)

    # Target_epoch must be greater than 0
    with pytest.raises(ValueError):
        callback = LinearWeightUpdate(
            target_epoch=0,
            condition_name=cond_name,
            initial_value=0,
            target_value=1,
        )
        trainer = Trainer(
            solver=solver,
            callbacks=[callback],
            accelerator="cpu",
            max_epochs=5,
        )
        trainer.train()

    # Target_epoch must be less than or equal to max_epochs
    with pytest.raises(ValueError):
        callback = LinearWeightUpdate(
            target_epoch=epochs,
            condition_name=cond_name,
            initial_value=0,
            target_value=1,
        )
        trainer = Trainer(
            solver=solver,
            callbacks=[callback],
            accelerator="cpu",
            max_epochs=epochs - 1,
        )
        trainer.train()

    # Condition_name must be a problem condition
    with pytest.raises(ValueError):
        callback = LinearWeightUpdate(
            target_epoch=epochs,
            condition_name="not_a_condition",
            initial_value=0,
            target_value=1,
        )
        trainer = Trainer(
            solver=solver,
            callbacks=[callback],
            accelerator="cpu",
            max_epochs=epochs,
        )
        trainer.train()

    # Weighting schema must be ScalarWeighting
    with pytest.raises(ValueError):
        callback = LinearWeightUpdate(
            target_epoch=epochs,
            condition_name=cond_name,
            initial_value=0,
            target_value=1,
        )
        unweighted_solver = PINN(problem=poisson_problem, model=model)
        trainer = Trainer(
            solver=unweighted_solver,
            callbacks=[callback],
            accelerator="cpu",
            max_epochs=epochs,
        )
        trainer.train()
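The removed file tested the LinearWeightUpdate callback, which ramps one condition's scalar weight from initial_value to target_value over target_epoch epochs. A minimal sketch of that schedule, assuming a straight-line interpolation per epoch (the helper below is illustrative only, not PINA's implementation):

def linear_weight(epoch, target_epoch, initial_value, target_value):
    # Fraction of the ramp completed, clamped to [0, 1].
    fraction = min(max(epoch / target_epoch, 0.0), 1.0)
    # Straight-line interpolation between the two endpoints.
    return initial_value + fraction * (target_value - initial_value)

# At epoch == target_epoch the weight equals target_value, which is what
# test_training asserted via math.isclose on solver.weighting.weights[cond_name].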
@@ -12,22 +12,42 @@ problem.discretise_domain(10)
model = FeedForward(len(problem.input_variables), len(problem.output_variables))


@pytest.mark.parametrize("update_every_n_epochs", [1, 10, 100, 1000])
@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
def test_constructor(alpha):
    NeuralTangentKernelWeighting(alpha=alpha)
def test_constructor(update_every_n_epochs, alpha):
    NeuralTangentKernelWeighting(
        update_every_n_epochs=update_every_n_epochs, alpha=alpha
    )

    # Should fail if alpha is not >= 0
    with pytest.raises(ValueError):
        NeuralTangentKernelWeighting(alpha=-0.1)
        NeuralTangentKernelWeighting(
            update_every_n_epochs=update_every_n_epochs, alpha=-0.1
        )

    # Should fail if alpha is not <= 1
    with pytest.raises(ValueError):
        NeuralTangentKernelWeighting(alpha=1.1)

    # Should fail if update_every_n_epochs is not an integer
    with pytest.raises(AssertionError):
        NeuralTangentKernelWeighting(update_every_n_epochs=1.5)

    # Should fail if update_every_n_epochs is not > 0
    with pytest.raises(AssertionError):
        NeuralTangentKernelWeighting(update_every_n_epochs=0)

    # Should fail if update_every_n_epochs is not > 0
    with pytest.raises(AssertionError):
        NeuralTangentKernelWeighting(update_every_n_epochs=-3)


@pytest.mark.parametrize("update_every_n_epochs", [1, 3])
@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
def test_train_aggregation(alpha):
    weighting = NeuralTangentKernelWeighting(alpha=alpha)
def test_train_aggregation(update_every_n_epochs, alpha):
    weighting = NeuralTangentKernelWeighting(
        update_every_n_epochs=update_every_n_epochs, alpha=alpha
    )
    solver = PINN(problem=problem, model=model, weighting=weighting)
    trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
    trainer.train()
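The refactor adds an update_every_n_epochs argument next to alpha; the checks above pin down its contract (positive integer, AssertionError otherwise) while alpha stays in [0, 1]. As a rough sketch of the general pattern such a pair of parameters drives, an epoch-gated exponential moving average over per-condition weights (a generic illustration with an assumed blending direction, not PINA's NeuralTangentKernelWeighting update rule):

def ema_update(epoch, update_every_n_epochs, alpha, current, estimate):
    # Skip the update except every `update_every_n_epochs`-th epoch.
    if epoch % update_every_n_epochs != 0:
        return current
    # Blend the previous weights with freshly estimated ones, alpha in [0, 1].
    return {c: alpha * current[c] + (1 - alpha) * estimate[c] for c in current}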
@@ -29,20 +29,6 @@ def test_constructor(weights):
        ScalarWeighting(weights=[1, 2, 3])


@pytest.mark.parametrize(
    "weights", [1, 1.0, dict(zip(condition_names, [1] * len(condition_names)))]
)
def test_aggregate(weights):
    weighting = ScalarWeighting(weights=weights)
    losses = dict(
        zip(
            condition_names,
            [torch.randn(1) for _ in range(len(condition_names))],
        )
    )
    weighting.aggregate(losses=losses)


@pytest.mark.parametrize(
    "weights", [1, 1.0, dict(zip(condition_names, [1] * len(condition_names)))]
)
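test_aggregate above only asserts that aggregate accepts a dictionary of per-condition losses. For reference, a minimal sketch of the weighted reduction such a call is expected to perform, assuming ScalarWeighting scales each condition loss by its scalar weight and sums the results (the helper and condition names below are hypothetical, not PINA's implementation):

import torch

def weighted_sum(losses, weights):
    # Multiply every condition loss by its scalar weight, then sum to one scalar.
    return sum(weights.get(name, 1.0) * loss for name, loss in losses.items())

# Hypothetical two-condition example mirroring the test's random losses.
example_losses = {"laplace_equation": torch.randn(1), "boundary": torch.randn(1)}
print(weighted_sum(example_losses, weights={"laplace_equation": 1.0, "boundary": 1.0}))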
@@ -12,26 +12,28 @@ problem.discretise_domain(10)
model = FeedForward(len(problem.input_variables), len(problem.output_variables))


@pytest.mark.parametrize("k", [10, 100, 1000])
def test_constructor(k):
    SelfAdaptiveWeighting(k=k)
@pytest.mark.parametrize("update_every_n_epochs", [10, 100, 1000])
def test_constructor(update_every_n_epochs):
    SelfAdaptiveWeighting(update_every_n_epochs=update_every_n_epochs)

    # Should fail if k is not an integer
    # Should fail if update_every_n_epochs is not an integer
    with pytest.raises(AssertionError):
        SelfAdaptiveWeighting(k=1.5)
        SelfAdaptiveWeighting(update_every_n_epochs=1.5)

    # Should fail if k is not > 0
    # Should fail if update_every_n_epochs is not > 0
    with pytest.raises(AssertionError):
        SelfAdaptiveWeighting(k=0)
        SelfAdaptiveWeighting(update_every_n_epochs=0)

    # Should fail if k is not > 0
    # Should fail if update_every_n_epochs is not > 0
    with pytest.raises(AssertionError):
        SelfAdaptiveWeighting(k=-3)
        SelfAdaptiveWeighting(update_every_n_epochs=-3)


@pytest.mark.parametrize("k", [2, 3])
def test_train_aggregation(k):
    weighting = SelfAdaptiveWeighting(k=k)
@pytest.mark.parametrize("update_every_n_epochs", [1, 3])
def test_train_aggregation(update_every_n_epochs):
    weighting = SelfAdaptiveWeighting(
        update_every_n_epochs=update_every_n_epochs
    )
    solver = PINN(problem=problem, model=model, weighting=weighting)
    trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
    trainer.train()