add batching support for self-adaptive pinns
This commit is contained in:
committed by
Giovanni Canali
parent
1ed14916f1
commit
6d1d4ef423
@@ -42,9 +42,11 @@ model = FeedForward(len(problem.input_variables), len(problem.output_variables))
|
||||
@pytest.mark.parametrize("problem", [problem, inverse_problem])
|
||||
@pytest.mark.parametrize("weight_fn", [torch.nn.Sigmoid(), torch.nn.Tanh()])
|
||||
def test_constructor(problem, weight_fn):
|
||||
|
||||
solver = SAPINN(problem=problem, model=model, weight_function=weight_fn)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
SAPINN(model=model, problem=problem, weight_function=1)
|
||||
solver = SAPINN(problem=problem, model=model, weight_function=weight_fn)
|
||||
|
||||
assert solver.accepted_conditions_types == (
|
||||
InputTargetCondition,
|
||||
@@ -53,26 +55,13 @@ def test_constructor(problem, weight_fn):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("problem", [problem, inverse_problem])
|
||||
def test_wrong_batch(problem):
|
||||
with pytest.raises(NotImplementedError):
|
||||
solver = SAPINN(model=model, problem=problem)
|
||||
trainer = Trainer(
|
||||
solver=solver,
|
||||
max_epochs=2,
|
||||
accelerator="cpu",
|
||||
batch_size=10,
|
||||
train_size=1.0,
|
||||
val_size=0.0,
|
||||
test_size=0.0,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("problem", [problem, inverse_problem])
|
||||
@pytest.mark.parametrize("compile", [True, False])
|
||||
def test_solver_train(problem, compile):
|
||||
solver = SAPINN(model=model, problem=problem)
|
||||
@pytest.mark.parametrize(
|
||||
"loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()]
|
||||
)
|
||||
def test_solver_train(problem, compile, loss):
|
||||
solver = SAPINN(model=model, problem=problem, loss=loss)
|
||||
trainer = Trainer(
|
||||
solver=solver,
|
||||
max_epochs=2,
|
||||
@@ -95,8 +84,11 @@ def test_solver_train(problem, compile):
|
||||
|
||||
@pytest.mark.parametrize("problem", [problem, inverse_problem])
|
||||
@pytest.mark.parametrize("compile", [True, False])
|
||||
def test_solver_validation(problem, compile):
|
||||
solver = SAPINN(model=model, problem=problem)
|
||||
@pytest.mark.parametrize(
|
||||
"loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()]
|
||||
)
|
||||
def test_solver_validation(problem, compile, loss):
|
||||
solver = SAPINN(model=model, problem=problem, loss=loss)
|
||||
trainer = Trainer(
|
||||
solver=solver,
|
||||
max_epochs=2,
|
||||
@@ -119,8 +111,11 @@ def test_solver_validation(problem, compile):
|
||||
|
||||
@pytest.mark.parametrize("problem", [problem, inverse_problem])
|
||||
@pytest.mark.parametrize("compile", [True, False])
|
||||
def test_solver_test(problem, compile):
|
||||
solver = SAPINN(model=model, problem=problem)
|
||||
@pytest.mark.parametrize(
|
||||
"loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()]
|
||||
)
|
||||
def test_solver_test(problem, compile, loss):
|
||||
solver = SAPINN(model=model, problem=problem, loss=loss)
|
||||
trainer = Trainer(
|
||||
solver=solver,
|
||||
max_epochs=2,
|
||||
|
||||
Reference in New Issue
Block a user