weighting refactory

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
giovanni
2025-09-01 11:00:14 +02:00
committed by Giovanni Canali
parent c42bdd575c
commit 96402baf20
12 changed files with 214 additions and 388 deletions

View File

@@ -29,20 +29,6 @@ def test_constructor(weights):
ScalarWeighting(weights=[1, 2, 3])
@pytest.mark.parametrize(
"weights", [1, 1.0, dict(zip(condition_names, [1] * len(condition_names)))]
)
def test_aggregate(weights):
weighting = ScalarWeighting(weights=weights)
losses = dict(
zip(
condition_names,
[torch.randn(1) for _ in range(len(condition_names))],
)
)
weighting.aggregate(losses=losses)
@pytest.mark.parametrize(
"weights", [1, 1.0, dict(zip(condition_names, [1] * len(condition_names)))]
)