weighting refactory
Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Giovanni Canali
parent
c42bdd575c
commit
96402baf20
@@ -4,22 +4,6 @@ from .weighting_interface import WeightingInterface
|
||||
from ..utils import check_consistency
|
||||
|
||||
|
||||
class _NoWeighting(WeightingInterface):
|
||||
"""
|
||||
Weighting scheme that does not apply any weighting to the losses.
|
||||
"""
|
||||
|
||||
def aggregate(self, losses):
|
||||
"""
|
||||
Aggregate the losses.
|
||||
|
||||
:param dict losses: The dictionary of losses.
|
||||
:return: The aggregated losses.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return sum(losses.values())
|
||||
|
||||
|
||||
class ScalarWeighting(WeightingInterface):
|
||||
"""
|
||||
Weighting scheme that assigns a scalar weight to each loss term.
|
||||
@@ -36,28 +20,42 @@ class ScalarWeighting(WeightingInterface):
|
||||
dictionary, the default value is used.
|
||||
:type weights: float | int | dict
|
||||
"""
|
||||
super().__init__()
|
||||
super().__init__(update_every_n_epochs=1, aggregator="sum")
|
||||
|
||||
# Check consistency
|
||||
check_consistency([weights], (float, dict, int))
|
||||
|
||||
# Weights initialization
|
||||
if isinstance(weights, (float, int)):
|
||||
# Initialization
|
||||
if isinstance(weights, dict):
|
||||
self.values = weights
|
||||
self.default_value_weights = 1
|
||||
elif isinstance(weights, (float, int)):
|
||||
self.values = {}
|
||||
self.default_value_weights = weights
|
||||
self.weights = {}
|
||||
else:
|
||||
self.default_value_weights = 1.0
|
||||
self.weights = weights
|
||||
raise ValueError
|
||||
|
||||
def aggregate(self, losses):
|
||||
def weights_update(self, losses):
|
||||
"""
|
||||
Aggregate the losses.
|
||||
Update the weighting scheme based on the given losses.
|
||||
|
||||
:param dict losses: The dictionary of losses.
|
||||
:return: The aggregated losses.
|
||||
:rtype: torch.Tensor
|
||||
:return: The updated weights.
|
||||
:rtype: dict
|
||||
"""
|
||||
return sum(
|
||||
self.weights.get(condition, self.default_value_weights) * loss
|
||||
for condition, loss in losses.items()
|
||||
)
|
||||
return {
|
||||
condition: self.values.get(condition, self.default_value_weights)
|
||||
for condition in losses.keys()
|
||||
}
|
||||
|
||||
|
||||
class _NoWeighting(ScalarWeighting):
|
||||
"""
|
||||
Weighting scheme that does not apply any weighting to the losses.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialization of the :class:`_NoWeighting` class.
|
||||
"""
|
||||
super().__init__(weights=1)
|
||||
|
||||
Reference in New Issue
Block a user