weighting refactory
Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Giovanni Canali
parent
c42bdd575c
commit
96402baf20
@@ -29,20 +29,6 @@ def test_constructor(weights):
|
||||
ScalarWeighting(weights=[1, 2, 3])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"weights", [1, 1.0, dict(zip(condition_names, [1] * len(condition_names)))]
|
||||
)
|
||||
def test_aggregate(weights):
|
||||
weighting = ScalarWeighting(weights=weights)
|
||||
losses = dict(
|
||||
zip(
|
||||
condition_names,
|
||||
[torch.randn(1) for _ in range(len(condition_names))],
|
||||
)
|
||||
)
|
||||
weighting.aggregate(losses=losses)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"weights", [1, 1.0, dict(zip(condition_names, [1] * len(condition_names)))]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user