This commit is contained in:
Nicola Demo
2024-09-09 10:50:54 +02:00
parent 9d9c2aa23e
commit f0d68b34c7
23 changed files with 480 additions and 229 deletions

View File

@@ -0,0 +1,35 @@
""" Module for Loss Interface """
from .weightning_interface import weightningInterface
class WeightedAggregation(WeightningInterface):
"""
TODO
"""
def __init__(self, aggr='mean', weights=None):
self.aggr = aggr
self.weights = weights
def aggregate(self, losses):
"""
Aggregate the losses.
:param dict(torch.Tensor) input: The dictionary of losses.
:return: The losses aggregation. It should be a scalar Tensor.
:rtype: torch.Tensor
"""
if self.weights:
weighted_losses = {
condition: self.weights[condition] * losses[condition]
for condition in losses
}
else:
weighted_losses = losses
if self.aggr == 'mean':
return sum(weighted_losses.values()) / len(weighted_losses)
elif self.aggr == 'sum':
return sum(weighted_losses.values())
else:
raise ValueError(self.aggr + " is not valid for aggregation.")