Fix Codacy Warnings (#477)

---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2025-03-10 15:38:45 +01:00
committed by Nicola Demo
parent e3790e049a
commit 4177bfbb50
157 changed files with 3473 additions and 3839 deletions

View File

@@ -12,31 +12,40 @@ model = FeedForward(len(problem.input_variables), len(problem.output_variables))
condition_names = problem.conditions.keys()
print(problem.conditions.keys())
@pytest.mark.parametrize("weights",
[1, 1., dict(zip(condition_names, [1]*len(condition_names)))])
@pytest.mark.parametrize(
"weights", [1, 1.0, dict(zip(condition_names, [1] * len(condition_names)))]
)
def test_constructor(weights):
ScalarWeighting(weights=weights)
@pytest.mark.parametrize("weights", ['a', [1,2,3]])
@pytest.mark.parametrize("weights", ["a", [1, 2, 3]])
def test_wrong_constructor(weights):
with pytest.raises(ValueError):
ScalarWeighting(weights=weights)
@pytest.mark.parametrize("weights",
[1, 1., dict(zip(condition_names, [1]*len(condition_names)))])
@pytest.mark.parametrize(
"weights", [1, 1.0, dict(zip(condition_names, [1] * len(condition_names)))]
)
def test_aggregate(weights):
weighting = ScalarWeighting(weights=weights)
losses = dict(zip(condition_names, [torch.randn(1) for _ in range(len(condition_names))]))
losses = dict(
zip(
condition_names,
[torch.randn(1) for _ in range(len(condition_names))],
)
)
weighting.aggregate(losses=losses)
@pytest.mark.parametrize("weights",
[1, 1., dict(zip(condition_names, [1]*len(condition_names)))])
@pytest.mark.parametrize(
"weights", [1, 1.0, dict(zip(condition_names, [1] * len(condition_names)))]
)
def test_train_aggregation(weights):
weighting = ScalarWeighting(weights=weights)
problem.discretise_domain(50)
solver = PINN(
problem=problem,
model=model,
weighting=weighting)
trainer = Trainer(solver=solver, max_epochs=5, accelerator='cpu')
trainer.train()
solver = PINN(problem=problem, model=model, weighting=weighting)
trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
trainer.train()