Fix Codacy Warnings (#477)
--------- Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
e3790e049a
commit
4177bfbb50
@@ -12,31 +12,40 @@ model = FeedForward(len(problem.input_variables), len(problem.output_variables))
|
||||
condition_names = problem.conditions.keys()
|
||||
print(problem.conditions.keys())
|
||||
|
||||
@pytest.mark.parametrize("weights",
|
||||
[1, 1., dict(zip(condition_names, [1]*len(condition_names)))])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"weights", [1, 1.0, dict(zip(condition_names, [1] * len(condition_names)))]
|
||||
)
|
||||
def test_constructor(weights):
|
||||
ScalarWeighting(weights=weights)
|
||||
|
||||
@pytest.mark.parametrize("weights", ['a', [1,2,3]])
|
||||
|
||||
@pytest.mark.parametrize("weights", ["a", [1, 2, 3]])
|
||||
def test_wrong_constructor(weights):
|
||||
with pytest.raises(ValueError):
|
||||
ScalarWeighting(weights=weights)
|
||||
|
||||
@pytest.mark.parametrize("weights",
|
||||
[1, 1., dict(zip(condition_names, [1]*len(condition_names)))])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"weights", [1, 1.0, dict(zip(condition_names, [1] * len(condition_names)))]
|
||||
)
|
||||
def test_aggregate(weights):
|
||||
weighting = ScalarWeighting(weights=weights)
|
||||
losses = dict(zip(condition_names, [torch.randn(1) for _ in range(len(condition_names))]))
|
||||
losses = dict(
|
||||
zip(
|
||||
condition_names,
|
||||
[torch.randn(1) for _ in range(len(condition_names))],
|
||||
)
|
||||
)
|
||||
weighting.aggregate(losses=losses)
|
||||
|
||||
@pytest.mark.parametrize("weights",
|
||||
[1, 1., dict(zip(condition_names, [1]*len(condition_names)))])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"weights", [1, 1.0, dict(zip(condition_names, [1] * len(condition_names)))]
|
||||
)
|
||||
def test_train_aggregation(weights):
|
||||
weighting = ScalarWeighting(weights=weights)
|
||||
problem.discretise_domain(50)
|
||||
solver = PINN(
|
||||
problem=problem,
|
||||
model=model,
|
||||
weighting=weighting)
|
||||
trainer = Trainer(solver=solver, max_epochs=5, accelerator='cpu')
|
||||
trainer.train()
|
||||
solver = PINN(problem=problem, model=model, weighting=weighting)
|
||||
trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
|
||||
trainer.train()
|
||||
|
||||
Reference in New Issue
Block a user