batching for rbapinns

This commit is contained in:
giovanni
2025-06-16 14:17:54 +02:00
committed by Dario Coscia
parent 3778ef7ee2
commit de47d69fec
2 changed files with 215 additions and 91 deletions

View File

@@ -42,10 +42,14 @@ model = FeedForward(len(problem.input_variables), len(problem.output_variables))
@pytest.mark.parametrize("eta", [1, 0.001])
@pytest.mark.parametrize("gamma", [0.5, 0.9])
def test_constructor(problem, eta, gamma):
with pytest.raises(AssertionError):
solver = RBAPINN(model=model, problem=problem, gamma=1.5)
solver = RBAPINN(model=model, problem=problem, eta=eta, gamma=gamma)
with pytest.raises(ValueError):
solver = RBAPINN(model=model, problem=problem, gamma=1.5)
with pytest.raises(ValueError):
solver = RBAPINN(model=model, problem=problem, eta=-0.1)
assert solver.accepted_conditions_types == (
InputTargetCondition,
InputEquationCondition,
@@ -54,30 +58,18 @@ def test_constructor(problem, eta, gamma):
@pytest.mark.parametrize("problem", [problem, inverse_problem])
def test_wrong_batch(problem):
with pytest.raises(NotImplementedError):
solver = RBAPINN(model=model, problem=problem)
trainer = Trainer(
solver=solver,
max_epochs=2,
accelerator="cpu",
batch_size=10,
train_size=1.0,
val_size=0.0,
test_size=0.0,
)
trainer.train()
@pytest.mark.parametrize("problem", [problem, inverse_problem])
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("compile", [True, False])
def test_solver_train(problem, compile):
solver = RBAPINN(model=model, problem=problem)
@pytest.mark.parametrize(
"loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()]
)
def test_solver_train(problem, batch_size, loss, compile):
solver = RBAPINN(model=model, problem=problem, loss=loss)
trainer = Trainer(
solver=solver,
max_epochs=2,
accelerator="cpu",
batch_size=None,
batch_size=batch_size,
train_size=1.0,
val_size=0.0,
test_size=0.0,
@@ -89,14 +81,18 @@ def test_solver_train(problem, compile):
@pytest.mark.parametrize("problem", [problem, inverse_problem])
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("compile", [True, False])
def test_solver_validation(problem, compile):
solver = RBAPINN(model=model, problem=problem)
@pytest.mark.parametrize(
"loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()]
)
def test_solver_validation(problem, batch_size, loss, compile):
solver = RBAPINN(model=model, problem=problem, loss=loss)
trainer = Trainer(
solver=solver,
max_epochs=2,
accelerator="cpu",
batch_size=None,
batch_size=batch_size,
train_size=0.9,
val_size=0.1,
test_size=0.0,
@@ -108,14 +104,18 @@ def test_solver_validation(problem, compile):
@pytest.mark.parametrize("problem", [problem, inverse_problem])
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("compile", [True, False])
def test_solver_test(problem, compile):
solver = RBAPINN(model=model, problem=problem)
@pytest.mark.parametrize(
"loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()]
)
def test_solver_test(problem, batch_size, loss, compile):
solver = RBAPINN(model=model, problem=problem, loss=loss)
trainer = Trainer(
solver=solver,
max_epochs=2,
accelerator="cpu",
batch_size=None,
batch_size=batch_size,
train_size=0.7,
val_size=0.2,
test_size=0.1,