add linear weighting
This commit is contained in:
committed by
Giovanni Canali
parent
96402baf20
commit
ef3542486c
@@ -17,7 +17,7 @@ class ScalarWeighting(WeightingInterface):
|
||||
If a single scalar value is provided, it is assigned to all loss
|
||||
terms. If a dictionary is provided, the keys are the conditions and
|
||||
the values are the weights. If a condition is not present in the
|
||||
dictionary, the default value is used.
|
||||
dictionary, the default value (1) is used.
|
||||
:type weights: float | int | dict
|
||||
"""
|
||||
super().__init__(update_every_n_epochs=1, aggregator="sum")
|
||||
@@ -29,11 +29,9 @@ class ScalarWeighting(WeightingInterface):
|
||||
if isinstance(weights, dict):
|
||||
self.values = weights
|
||||
self.default_value_weights = 1
|
||||
elif isinstance(weights, (float, int)):
|
||||
else:
|
||||
self.values = {}
|
||||
self.default_value_weights = weights
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
def weights_update(self, losses):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user