weighting refactory
Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Giovanni Canali
parent
c42bdd575c
commit
96402baf20
@@ -12,22 +12,42 @@ problem.discretise_domain(10)
|
||||
model = FeedForward(len(problem.input_variables), len(problem.output_variables))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("update_every_n_epochs", [1, 10, 100, 1000])
|
||||
@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
|
||||
def test_constructor(alpha):
|
||||
NeuralTangentKernelWeighting(alpha=alpha)
|
||||
def test_constructor(update_every_n_epochs, alpha):
|
||||
NeuralTangentKernelWeighting(
|
||||
update_every_n_epochs=update_every_n_epochs, alpha=alpha
|
||||
)
|
||||
|
||||
# Should fail if alpha is not >= 0
|
||||
with pytest.raises(ValueError):
|
||||
NeuralTangentKernelWeighting(alpha=-0.1)
|
||||
NeuralTangentKernelWeighting(
|
||||
update_every_n_epochs=update_every_n_epochs, alpha=-0.1
|
||||
)
|
||||
|
||||
# Should fail if alpha is not <= 1
|
||||
with pytest.raises(ValueError):
|
||||
NeuralTangentKernelWeighting(alpha=1.1)
|
||||
|
||||
# Should fail if update_every_n_epochs is not an integer
|
||||
with pytest.raises(AssertionError):
|
||||
NeuralTangentKernelWeighting(update_every_n_epochs=1.5)
|
||||
|
||||
# Should fail if update_every_n_epochs is not > 0
|
||||
with pytest.raises(AssertionError):
|
||||
NeuralTangentKernelWeighting(update_every_n_epochs=0)
|
||||
|
||||
# Should fail if update_every_n_epochs is not > 0
|
||||
with pytest.raises(AssertionError):
|
||||
NeuralTangentKernelWeighting(update_every_n_epochs=-3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("update_every_n_epochs", [1, 3])
|
||||
@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
|
||||
def test_train_aggregation(alpha):
|
||||
weighting = NeuralTangentKernelWeighting(alpha=alpha)
|
||||
def test_train_aggregation(update_every_n_epochs, alpha):
|
||||
weighting = NeuralTangentKernelWeighting(
|
||||
update_every_n_epochs=update_every_n_epochs, alpha=alpha
|
||||
)
|
||||
solver = PINN(problem=problem, model=model, weighting=weighting)
|
||||
trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
|
||||
trainer.train()
|
||||
|
||||
Reference in New Issue
Block a user