add batching support for self-adaptive pinns

This commit is contained in:
Giovanni Canali
2025-07-22 14:45:43 +02:00
committed by Giovanni Canali
parent 1ed14916f1
commit 6d1d4ef423
2 changed files with 208 additions and 140 deletions

View File

@@ -42,9 +42,11 @@ model = FeedForward(len(problem.input_variables), len(problem.output_variables))
@pytest.mark.parametrize("problem", [problem, inverse_problem])
@pytest.mark.parametrize("weight_fn", [torch.nn.Sigmoid(), torch.nn.Tanh()])
def test_constructor(problem, weight_fn):
solver = SAPINN(problem=problem, model=model, weight_function=weight_fn)
with pytest.raises(ValueError):
SAPINN(model=model, problem=problem, weight_function=1)
solver = SAPINN(problem=problem, model=model, weight_function=weight_fn)
assert solver.accepted_conditions_types == (
InputTargetCondition,
@@ -53,26 +55,13 @@ def test_constructor(problem, weight_fn):
)
@pytest.mark.parametrize("problem", [problem, inverse_problem])
def test_wrong_batch(problem):
with pytest.raises(NotImplementedError):
solver = SAPINN(model=model, problem=problem)
trainer = Trainer(
solver=solver,
max_epochs=2,
accelerator="cpu",
batch_size=10,
train_size=1.0,
val_size=0.0,
test_size=0.0,
)
trainer.train()
@pytest.mark.parametrize("problem", [problem, inverse_problem])
@pytest.mark.parametrize("compile", [True, False])
def test_solver_train(problem, compile):
solver = SAPINN(model=model, problem=problem)
@pytest.mark.parametrize(
"loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()]
)
def test_solver_train(problem, compile, loss):
solver = SAPINN(model=model, problem=problem, loss=loss)
trainer = Trainer(
solver=solver,
max_epochs=2,
@@ -95,8 +84,11 @@ def test_solver_train(problem, compile):
@pytest.mark.parametrize("problem", [problem, inverse_problem])
@pytest.mark.parametrize("compile", [True, False])
def test_solver_validation(problem, compile):
solver = SAPINN(model=model, problem=problem)
@pytest.mark.parametrize(
"loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()]
)
def test_solver_validation(problem, compile, loss):
solver = SAPINN(model=model, problem=problem, loss=loss)
trainer = Trainer(
solver=solver,
max_epochs=2,
@@ -119,8 +111,11 @@ def test_solver_validation(problem, compile):
@pytest.mark.parametrize("problem", [problem, inverse_problem])
@pytest.mark.parametrize("compile", [True, False])
def test_solver_test(problem, compile):
solver = SAPINN(model=model, problem=problem)
@pytest.mark.parametrize(
"loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()]
)
def test_solver_test(problem, compile, loss):
solver = SAPINN(model=model, problem=problem, loss=loss)
trainer = Trainer(
solver=solver,
max_epochs=2,