weighting refactory

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
giovanni
2025-09-01 11:00:14 +02:00
committed by Giovanni Canali
parent c42bdd575c
commit 96402baf20
12 changed files with 214 additions and 388 deletions

View File

@@ -4,22 +4,6 @@ from .weighting_interface import WeightingInterface
from ..utils import check_consistency
class _NoWeighting(WeightingInterface):
"""
Weighting scheme that does not apply any weighting to the losses.
"""
def aggregate(self, losses):
"""
Aggregate the losses.
:param dict losses: The dictionary of losses.
:return: The aggregated losses.
:rtype: torch.Tensor
"""
return sum(losses.values())
class ScalarWeighting(WeightingInterface):
"""
Weighting scheme that assigns a scalar weight to each loss term.
@@ -36,28 +20,42 @@ class ScalarWeighting(WeightingInterface):
dictionary, the default value is used.
:type weights: float | int | dict
"""
super().__init__()
super().__init__(update_every_n_epochs=1, aggregator="sum")
# Check consistency
check_consistency([weights], (float, dict, int))
# Weights initialization
if isinstance(weights, (float, int)):
# Initialization
if isinstance(weights, dict):
self.values = weights
self.default_value_weights = 1
elif isinstance(weights, (float, int)):
self.values = {}
self.default_value_weights = weights
self.weights = {}
else:
self.default_value_weights = 1.0
self.weights = weights
raise ValueError
def aggregate(self, losses):
def weights_update(self, losses):
"""
Aggregate the losses.
Update the weighting scheme based on the given losses.
:param dict losses: The dictionary of losses.
:return: The aggregated losses.
:rtype: torch.Tensor
:return: The updated weights.
:rtype: dict
"""
return sum(
self.weights.get(condition, self.default_value_weights) * loss
for condition, loss in losses.items()
)
return {
condition: self.values.get(condition, self.default_value_weights)
for condition in losses.keys()
}
class _NoWeighting(ScalarWeighting):
"""
Weighting scheme that does not apply any weighting to the losses.
"""
def __init__(self):
"""
Initialization of the :class:`_NoWeighting` class.
"""
super().__init__(weights=1)