Files
PINA/pina/loss/scalar_weighting.py
Dario Coscia 42ab1a666b Formatting
* Adding black as dev dependency
* Formatting pina code
* Formatting tests
2025-03-19 17:46:36 +01:00

39 lines
1.0 KiB
Python

"""Module for Loss Interface"""
from .weighting_interface import WeightingInterface
from ..utils import check_consistency
class _NoWeighting(WeightingInterface):
    """
    Weighting scheme that applies no weights: the loss terms are simply
    added together.
    """

    def aggregate(self, losses):
        """
        Aggregate the losses by plain summation.

        :param dict losses: The dictionary of losses; only its values are
            used.
        :return: The sum of all the values in ``losses``. For an empty
            dictionary the result is the integer ``0``.
        """
        # Start from the integer 0 and use a non in-place addition, exactly
        # as the built-in ``sum`` does, so no loss term is ever mutated.
        total = 0
        for term in losses.values():
            total = total + term
        return total
class ScalarWeighting(WeightingInterface):
    """
    Weighting scheme that multiplies every loss term by a scalar weight and
    sums the weighted terms.
    """

    def __init__(self, weights):
        """
        Initialization of the :class:`ScalarWeighting` class.

        :param weights: The weights applied to the loss terms. If a single
            number is given, it is used for every loss term. If a dictionary
            is given, its keys are the condition names and its values the
            corresponding weights; conditions missing from the dictionary
            are weighted by ``1``.
        :type weights: float | int | dict
        """
        super().__init__()
        # Type validation is delegated to the project helper.
        check_consistency([weights], (float, dict, int))

        is_scalar = isinstance(weights, (float, int))
        # A scalar becomes the fallback weight for every condition, with no
        # per-condition overrides; a dictionary provides the overrides and
        # the fallback weight is 1.
        self.default_value_weights = weights if is_scalar else 1
        self.weights = {} if is_scalar else weights

    def aggregate(self, losses):
        """
        Aggregate the losses as a weighted sum.

        :param dict(torch.Tensor) losses: The dictionary of losses.
        :return: The losses aggregation. It should be a scalar Tensor.
        :rtype: torch.Tensor
        """
        fallback = self.default_value_weights
        overrides = self.weights

        # Accumulate from the integer 0 with a non in-place addition, which
        # mirrors the semantics of the built-in ``sum``.
        total = 0
        for condition, loss in losses.items():
            total = total + overrides.get(condition, fallback) * loss
        return total