weighting refactory
Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Giovanni Canali
parent
c42bdd575c
commit
96402baf20
@@ -2,7 +2,7 @@
|
||||
|
||||
import torch
|
||||
from .weighting_interface import WeightingInterface
|
||||
from ..utils import check_consistency
|
||||
from ..utils import check_consistency, in_range
|
||||
|
||||
|
||||
class NeuralTangentKernelWeighting(WeightingInterface):
|
||||
@@ -20,32 +20,34 @@ class NeuralTangentKernelWeighting(WeightingInterface):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, alpha=0.5):
|
||||
def __init__(self, update_every_n_epochs=1, alpha=0.5):
|
||||
"""
|
||||
Initialization of the :class:`NeuralTangentKernelWeighting` class.
|
||||
|
||||
:param int update_every_n_epochs: The number of training epochs between
|
||||
weight updates. If set to 1, the weights are updated at every epoch.
|
||||
Default is 1.
|
||||
:param float alpha: The alpha parameter.
|
||||
:raises ValueError: If ``alpha`` is not between 0 and 1 (inclusive).
|
||||
"""
|
||||
super().__init__()
|
||||
super().__init__(update_every_n_epochs=update_every_n_epochs)
|
||||
|
||||
# Check consistency
|
||||
check_consistency(alpha, float)
|
||||
if alpha < 0 or alpha > 1:
|
||||
raise ValueError("alpha should be a value between 0 and 1")
|
||||
if not in_range(alpha, [0, 1], strict=False):
|
||||
raise ValueError("alpha must be in range (0, 1).")
|
||||
|
||||
# Initialize parameters
|
||||
self.alpha = alpha
|
||||
self.weights = {}
|
||||
self.default_value_weights = 1.0
|
||||
|
||||
def aggregate(self, losses):
|
||||
def weights_update(self, losses):
|
||||
"""
|
||||
Weight the losses according to the Neural Tangent Kernel algorithm.
|
||||
Update the weighting scheme based on the given losses.
|
||||
|
||||
:param dict(torch.Tensor) input: The dictionary of losses.
|
||||
:return: The aggregation of the losses. It should be a scalar Tensor.
|
||||
:rtype: torch.Tensor
|
||||
:param dict losses: The dictionary of losses.
|
||||
:return: The updated weights.
|
||||
:rtype: dict
|
||||
"""
|
||||
# Define a dictionary to store the norms of the gradients
|
||||
losses_norm = {}
|
||||
@@ -60,14 +62,10 @@ class NeuralTangentKernelWeighting(WeightingInterface):
|
||||
|
||||
# Update the weights
|
||||
self.weights = {
|
||||
condition: self.alpha
|
||||
* self.weights.get(condition, self.default_value_weights)
|
||||
condition: self.alpha * self.weights.get(condition, 1)
|
||||
+ (1 - self.alpha)
|
||||
* losses_norm[condition]
|
||||
/ sum(losses_norm.values())
|
||||
for condition in losses
|
||||
}
|
||||
|
||||
return sum(
|
||||
self.weights[condition] * loss for condition, loss in losses.items()
|
||||
)
|
||||
return self.weights
|
||||
|
||||
Reference in New Issue
Block a user