add linear weighting

This commit is contained in:
giovanni
2025-09-05 10:12:09 +02:00
committed by Giovanni Canali
parent 96402baf20
commit ef3542486c
7 changed files with 176 additions and 9 deletions

View File

@@ -17,7 +17,7 @@ class ScalarWeighting(WeightingInterface):
If a single scalar value is provided, it is assigned to all loss
terms. If a dictionary is provided, the keys are the conditions and
the values are the weights. If a condition is not present in the
dictionary, the default value is used.
dictionary, the default value (1) is used.
:type weights: float | int | dict
"""
super().__init__(update_every_n_epochs=1, aggregator="sum")
@@ -29,11 +29,9 @@ class ScalarWeighting(WeightingInterface):
if isinstance(weights, dict):
self.values = weights
self.default_value_weights = 1
elif isinstance(weights, (float, int)):
else:
self.values = {}
self.default_value_weights = weights
else:
raise ValueError
def weights_update(self, losses):
"""